MLBEDSW-4034: New Scheduler Size or Performance Optimisation

 - Merged dev/scheduler at 83639f90e8c828f70de6e29142355a940224959b

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I0050529d4b42da93768c7264296434dd877fb5b4
diff --git a/ethosu/vela/architecture_allocator.py b/ethosu/vela/architecture_allocator.py
new file mode 100644
index 0000000..c308a4a
--- /dev/null
+++ b/ethosu/vela/architecture_allocator.py
@@ -0,0 +1,389 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description: Architecture SHRAM allocator
+import enum
+import math
+from typing import Optional
+from typing import Tuple
+
+from .architecture_features import ArchitectureFeatures
+from .architecture_features import Block
+from .architecture_features import SHRAMConfig
+from .architecture_features import SHRAMElements
+from .ethos_u55_regs.ethos_u55_regs import resampling_mode
+from .numeric_util import round_up
+from .numeric_util import round_up_divide
+from .operation import Kernel
+from .operation import NpuBlockType
+from .range_set import MemoryRangeSet
+from .shape4d import Shape4D
+from .tensor import MemArea
+
+
+class SHRAMLayout:
+    def __init__(self):
+        self.ib_start = 0
+        self.ib_end = 0
+        self.ib_start2 = 0
+        self.ab_start = 0
+        self.lut_start = 0
+
+
+class ArchitectureBlockConfig:
+    def __init__(self):
+        self.layout = SHRAMLayout()
+        self.ifm_block = Shape4D()
+        self.ofm_block = Shape4D()
+        self.acc_type = SHRAMElements.Acc32
+        self.is_partkernel = False
+        self.bank_size = 0
+
+    def get_shram_memory_access_range(self):
+        # Returns the SHRAM memory access range used by this shared buffer,
+        # excluding access to LUT
+        return MemoryRangeSet(MemArea.Shram, 0, self.layout.lut_start * self.bank_size)
+
+    def old_style_representation(self):
+        return [self.ofm_block.height, self.ofm_block.width, self.ifm_block.depth, self.ofm_block.depth]
+
+    def __str__(self):
+        return str(self.old_style_representation())
+
+
+_AccumulatorBits = {SHRAMElements.Acc16: 16, SHRAMElements.Acc32: 32, SHRAMElements.Acc40: 40}
+
+
+class ElementwiseUsage(enum.IntEnum):
+    No = 0
+    Full = 1
+    Scalar = 2
+
+
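+# Illustrative note on the layout computed below: SHRAM banks are arranged as
+# [reserved output | IFM (double-buffered) | optional IFM2 | ... | accumulators | LUT],
+# with the IFM region growing upwards from the reserved banks and the accumulator
+# region growing downwards from the LUT boundary. For example, assuming 1 KB banks and
+# an 8-bit 16x16x32 IFM block: ifm_bytes = 256 * 32 = 8192, so the double-buffered IFM
+# needs ceil(8192 / 1024) * 2 = 16 banks before rounding up to the IFM bank granule.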
+def _try_block_config(
+    shram: SHRAMConfig,
+    ew_usage: ElementwiseUsage,
+    ofm_block: Block,
+    ifm_block: Block,
+    ifm_bits: int,
+    ifm_granule: int,
+    acc_bits: int,
+    acc_granule: int,
+    lut_banks: int,
+) -> Optional[SHRAMLayout]:
+    assert (acc_bits > 0) and (acc_granule > 0)
+    assert (ifm_bits >= 8) and ((ifm_bits % 8) == 0) and (ifm_granule > 0)
+
+    # Always need IFM space
+    ifm_bytes = ifm_block.elements_wh() * round_up((ifm_block.depth * ifm_bits) / 8, 8)
+    ifm_banks = round_up_divide(ifm_bytes, shram.bank_size_bytes) * 2
+    ifm_banks = round_up(ifm_banks, ifm_granule)
+
+    # Calculate SHRAM boundaries of the IFM and Accumulators
+    lut_start = shram.total_banks - lut_banks
+    ifm_end = shram.reserved_output_banks + ifm_banks
+    ifm2_start = ifm_end
+    acc_start = lut_start
+
+    # If not elementwise then we need accumulator space
+    if ew_usage == ElementwiseUsage.No:
+        acc_bytes = (ofm_block.elements_wh() * round_up(ofm_block.depth, 8) * acc_bits) // 8
+        acc_banks = round_up_divide(acc_bytes, shram.bank_size_bytes) * 2
+        acc_banks = round_up(acc_banks, acc_granule)
+        acc_start = acc_start - acc_banks
+    else:
+        ifm2_banks = ifm_banks if ew_usage == ElementwiseUsage.Full else 0
+        if ifm2_start + ifm2_banks > acc_start:
+            return None
+        ifm_end = acc_start
+
+    # IFM must still fit before accumulators
+    if ifm_end > acc_start:
+        return None
+
+    # Should all fit, so return this layout
+    layout = SHRAMLayout()
+    layout.ib_start = shram.reserved_output_banks
+    layout.ib_start2 = ifm2_start
+    layout.ib_end = ifm_end
+    layout.ab_start = acc_start
+    layout.lut_start = lut_start
+    return layout
+
+
+def _choose_kernel_method(ifm_shape: Shape4D, ifm_bits: int, kernel: Kernel) -> bool:
+    if ifm_shape.depth <= 8:
+        return True
+
+    # Compare part-kernel to depth-kernel and choose the one with the best utilisation
+    kernel_elements = kernel.elements_wh()
+    depth_utilisation = ifm_shape.depth / round_up(ifm_shape.depth, 32 if ifm_bits == 8 else 16)
+    part_utilisation = (
+        ifm_shape.depth
+        * kernel_elements
+        / (round_up(ifm_shape.depth, 8) * round_up(kernel_elements, 4 if ifm_bits == 8 else 2))
+    )
+
+    return part_utilisation > depth_utilisation
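+
+# Illustrative example of the choice above (hypothetical values): for an 8-bit IFM of
+# depth 16 with a 3x3 kernel, depth utilisation is 16 / round_up(16, 32) = 0.5 while
+# part-kernel utilisation is 16 * 9 / (round_up(16, 8) * round_up(9, 4)) = 144 / 192
+# = 0.75, so part-kernel first traversal is chosen.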
+
+
+def _ew_usage(npu_op_type: NpuBlockType, uses_scalar: bool) -> ElementwiseUsage:
+    ew_usage = ElementwiseUsage.No
+    if npu_op_type == NpuBlockType.ElementWise:
+        ew_usage = ElementwiseUsage.Full
+        if uses_scalar:
+            ew_usage = ElementwiseUsage.Scalar
+    return ew_usage
+
+
+def _acc_type(npu_op_type: NpuBlockType, ifm_bits: int, scaled: bool) -> int:
+    """Returns accumulator type"""
+    acc_type = SHRAMElements.Acc32
+    if (ifm_bits == 16) and npu_op_type != NpuBlockType.Pooling and scaled:
+        acc_type = SHRAMElements.Acc40
+    return acc_type
+
+
+def to_upscale(ifm_resampling: resampling_mode) -> int:
+    # Upscaling depending on resampling mode
+    return 1 if ifm_resampling == resampling_mode.NONE else 2
+
+
+def _ifm_blockdepth(arch, ifm_shape: Shape4D, ifm_bits: int, is_partkernel: bool):
+    if ifm_bits == 16:
+        ifm_blockdepth = round_up(min(ifm_shape.depth, 16), 4)
+    else:
+        ifm_blockdepth = round_up(min(ifm_shape.depth, 16 if is_partkernel else 32), arch.ifm_ublock.depth)
+    return ifm_blockdepth
+
+
+def _required_size(value: int, stride: int, border: int, upscale: int) -> int:
+    return int(math.ceil(((value - 1) * stride + border) / upscale))
+
+
+def get_ifm_area_required(ofm_shape: Shape4D, kernel: Kernel, upscale: int) -> Tuple[int, int]:
+    h1 = _required_size(ofm_shape.height, kernel.stride.y, kernel.area_height(), upscale)
+    w1 = _required_size(ofm_shape.width, kernel.stride.x, kernel.area_width(), upscale)
+    return (w1, h1)
+
+
+def _get_ifm_blocksize(
+    ofm_block: Shape4D, kernel: Kernel, ublock: Block, subkernel_limit: Block, upscale: int
+) -> Shape4D:
+    # IFM block height
+    h1 = _required_size(ofm_block.height, kernel.stride.y, min(kernel.area_height(), subkernel_limit.height), upscale)
+    h2 = h1
+    height = round_up(min(h1, h2), ublock.height)
+
+    # IFM block width
+    w1 = _required_size(ofm_block.width, kernel.stride.x, min(kernel.area_width(), subkernel_limit.width), upscale)
+    w2 = w1
+    width = round_up(min(w1, w2), ublock.width)
+
+    return Shape4D(1, height, width, ofm_block.depth)
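+
+# Illustrative example (hypothetical values): for an OFM block height of 8, stride 1,
+# a 3-high subkernel and no upscaling, _required_size gives (8 - 1) * 1 + 3 = 10,
+# which is then rounded up to the ublock height to give the IFM block height.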
+
+
+def find_block_config(
+    arch: ArchitectureFeatures,
+    npu_op_type: NpuBlockType,
+    ofm_shape: Shape4D,
+    ifm_shape: Shape4D,
+    ifm2_shape: Shape4D,
+    uses_scalar: bool,
+    ifm_bits: int,
+    kernel: Kernel,
+    lut_banks: int,
+    scaled: bool,
+    ifm_resampling: resampling_mode,
+) -> Optional[ArchitectureBlockConfig]:
+    SplitDepth = ArchitectureFeatures.OFMSplitDepth
+    # Elementwise larger-volume correction
+    if ifm2_shape is not None and ifm2_shape.elements() > ifm_shape.elements():
+        ifm_shape = ifm2_shape
+
+    # Figure out if SHRAM should be portioned for elementwise
+    ew_usage = _ew_usage(npu_op_type, uses_scalar)
+
+    # Operator typing help
+    is_pooling = npu_op_type == NpuBlockType.Pooling
+    is_depthwise = npu_op_type == NpuBlockType.ConvolutionDepthWise
+    is_equal_depth_op = (ew_usage != ElementwiseUsage.No) or is_pooling or is_depthwise
+    is_convolution = (npu_op_type == NpuBlockType.ConvolutionMxN) or is_depthwise
+
+    # Block config to be returned
+    config = ArchitectureBlockConfig()
+    config.is_partkernel = is_convolution and _choose_kernel_method(ifm_shape, ifm_bits, kernel)
+
+    # Accumulator & granule settings
+    config.acc_type = _acc_type(npu_op_type, ifm_bits, scaled)
+
+    # Memory rounding granules
+    acc_granule = arch.accumulator_granules[config.acc_type]
+    acc_bits = _AccumulatorBits[config.acc_type]
+    if ew_usage != ElementwiseUsage.No:
+        ifm_granule = arch.ifm_ew_bank_granules[ifm_bits]
+    else:
+        ifm_granule = arch.ifm_bank_granules[ifm_bits]
+    lut_banks = max(lut_banks, arch.shram.reserved_end_banks)
+    upscale = to_upscale(ifm_resampling)
+
+    # Subkernel repeats of the IFM
+    ifm_repeats = round_up_divide(kernel.area_width(), arch.SubKernelMax.width) * round_up_divide(
+        kernel.area_height(), arch.SubKernelMax.height
+    )
+    ifm_blockdepth = _ifm_blockdepth(arch, ifm_shape, ifm_bits, config.is_partkernel)
+
+    # Weights fetch (for operators that have them)
+    weight_fetch_wh = (kernel.area_width() * kernel.area_height()) if is_convolution else 0
+
+    search_space = Shape4D.min(ofm_shape, Shape4D(arch.ofm_block_max.to_hwc()))
+    search_space = Shape4D.round_up(search_space, Shape4D(arch.ofm_ublock.to_hwc()))
+
+    # Block WHC search, loops across the search space looking for best efficiency
+    best_cost = math.inf
+    depth = max(arch.ofm_ublock.depth, min(search_space.depth, SplitDepth))
+    if depth < ofm_shape.depth:
+        depth = round_up(depth, SplitDepth)
+
+    while depth <= search_space.depth:
+        wont_fit = {}
+        for height in range(arch.ofm_ublock.height, search_space.height + 1, arch.ofm_ublock.height):
+            for width in range(arch.ofm_ublock.width, search_space.width + 1, arch.ofm_ublock.width):
+                # Avoid checking W/H transposed blocks that already didn't fit. i.e. if 8x4x16 didn't
+                # fit, then 4x8x16 won't either.
+                if wont_fit.get((height, width), False):
+                    continue
+
+                # Calculate the IFM block dimensions required to feed this OFM block
+                ofm_block = Shape4D(1, height, width, depth)
+                ifm_block = _get_ifm_blocksize(ofm_block, kernel, arch.ofm_ublock, arch.SubKernelMax, upscale)
+                if not is_equal_depth_op:
+                    ifm_block = ifm_block.with_depth(ifm_blockdepth)
+
+                # Test if the IFM/OFM blocks fit into SHRAM
+                layout = _try_block_config(
+                    arch.shram, ew_usage, ofm_block, ifm_block, ifm_bits, ifm_granule, acc_bits, acc_granule, lut_banks
+                )
+
+                if layout:
+                    # Calculate cost in terms of OFM pixels per IFM+Weights fetch
+                    ifm_fetch = ifm_block.elements_wh() * ifm_shape.depth
+                    weight_fetch = weight_fetch_wh * ifm_shape.depth * (1 if is_depthwise else ofm_block.depth)
+                    relative_fetch = (ifm_fetch * ifm_repeats + weight_fetch) / ofm_block.elements()
+
+                    # Bias by the number of blocks we'd need to fill the OFM area (fewer, larger blocks are better)
+                    block_bias = round_up_divide(ofm_shape.height, ofm_block.height)
+                    block_bias *= round_up_divide(ofm_shape.width, ofm_block.width)
+                    # Check waste on all axes (prefer depth, width then height)
+                    waste_ratio = 1 + (1.2 * ((ofm_shape.depth % ofm_block.depth) / ofm_block.depth))
+                    waste_ratio *= 1 + (1.1 * ((ofm_shape.width % ofm_block.width) / ofm_block.width))
+                    waste_ratio *= 1 + (1.0 * ((ofm_shape.height % ofm_block.height) / ofm_block.height))
+
+                    # Bias for larger area coverage (or volume if not depthwise)
+                    area_bias = 1 / (ofm_block.height * ofm_block.width)
+                    if not (is_depthwise or is_pooling):
+                        area_bias = area_bias / ofm_block.depth
+
+                    relative_cost = relative_fetch * block_bias * waste_ratio * area_bias
+
+                    # If the entire IFM can be encompassed by both buffers, bias to prefer this configuration
+                    if ifm_shape.elements() < ifm_block.elements() * 2:
+                        relative_cost = relative_cost / 2
+
+                    if relative_cost < best_cost:
+                        best_cost = relative_cost
+                        config.layout = layout
+                        config.bank_size = arch.shram_bank_size
+                        config.ifm_block = ifm_block
+                        config.ofm_block = ofm_block
+                else:
+                    wont_fit[(width, height)] = True
+
+        depth = depth + arch.ofm_ublock.depth
+        if depth < ofm_shape.depth:
+            depth = round_up(depth, SplitDepth)
+
+    if best_cost != math.inf:
+        return config
+
+    return None
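+
+# A minimal usage sketch (illustrative only; 'arch' and the shapes/kernel below are
+# hypothetical and would normally be supplied by the scheduler):
+#
+#   config = find_block_config(
+#       arch, NpuBlockType.ConvolutionMxN,
+#       ofm_shape=Shape4D(1, 16, 16, 32), ifm_shape=Shape4D(1, 16, 16, 16),
+#       ifm2_shape=None, uses_scalar=False, ifm_bits=8, kernel=Kernel(3, 3),
+#       lut_banks=0, scaled=True, ifm_resampling=resampling_mode.NONE,
+#   )
+#   if config is not None:
+#       print(config)  # old-style [OFM height, OFM width, IFM depth, OFM depth]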
+
+
+def try_block_config(
+    block_config: Block,
+    arch: ArchitectureFeatures,
+    npu_op_type: NpuBlockType,
+    ifm_shape: Block,
+    ifm2_shape: Optional[Block],
+    uses_scalar: bool,
+    ifm_bits: int,
+    is_partkernel: bool,
+    kernel: Kernel,
+    lut_banks: int,
+    scaled: bool,
+    ifm_resampling: resampling_mode,
+) -> Optional[ArchitectureBlockConfig]:
+    """
+    Given a block_config, returns a corresponding ArchitectureBlockConfig.
+    Returns None if the block_config does not fit or is invalid.
+    """
+    # Check block config validity
+    if not all(
+        blk > 0 and blk <= blk_max and blk % ublk == 0
+        for blk, blk_max, ublk in zip(block_config.as_list(), arch.ofm_block_max.as_list(), arch.ofm_ublock.as_list())
+    ):
+        return None
+    # Elementwise larger-volume correction
+    if ifm2_shape is not None and ifm2_shape.elements() > ifm_shape.elements():
+        ifm_shape = ifm2_shape
+
+    ew_usage = _ew_usage(npu_op_type, uses_scalar)
+
+    # Operator typing help
+    is_pooling = npu_op_type == NpuBlockType.Pooling
+    is_depthwise = npu_op_type == NpuBlockType.ConvolutionDepthWise
+    is_equal_depth_op = (ew_usage != ElementwiseUsage.No) or is_pooling or is_depthwise
+
+    # Block config to be returned
+    config = ArchitectureBlockConfig()
+    config.is_partkernel = is_partkernel
+
+    # Accumulator & granule settings
+    config.acc_type = _acc_type(npu_op_type, ifm_bits, scaled)
+
+    # Memory rounding granules
+    acc_granule = arch.accumulator_granules[config.acc_type]
+    acc_bits = _AccumulatorBits[config.acc_type]
+    if ew_usage != ElementwiseUsage.No:
+        ifm_granule = arch.ifm_ew_bank_granules[ifm_bits]
+    else:
+        ifm_granule = arch.ifm_bank_granules[ifm_bits]
+    lut_banks = max(lut_banks, arch.shram.reserved_end_banks)
+    upscale = to_upscale(ifm_resampling)
+    ifm_blockdepth = _ifm_blockdepth(arch, ifm_shape, ifm_bits, is_partkernel)
+    ifm_block = _get_ifm_blocksize(block_config, kernel, arch.ofm_ublock, arch.SubKernelMax, upscale)
+    if not is_equal_depth_op:
+        ifm_block = ifm_block.with_depth(ifm_blockdepth)
+
+    layout = _try_block_config(
+        arch.shram, ew_usage, block_config, ifm_block, ifm_bits, ifm_granule, acc_bits, acc_granule, lut_banks
+    )
+    if layout is None:
+        return None
+    config.layout = layout
+    config.bank_size = arch.shram_bank_size
+    config.ifm_block = ifm_block
+    return config
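+
+
+# Illustrative sketch (hypothetical values): unlike find_block_config(), this function
+# validates an externally proposed block rather than searching for one, e.g.
+#
+#   cfg = try_block_config(
+#       Block(16, 8, 32), arch, NpuBlockType.ConvolutionMxN,
+#       ifm_shape=Block(56, 56, 16), ifm2_shape=None, uses_scalar=False, ifm_bits=8,
+#       is_partkernel=False, kernel=Kernel(3, 3), lut_banks=0, scaled=True,
+#       ifm_resampling=resampling_mode.NONE,
+#   )
+#   # cfg is None if the proposed block does not fit in SHRAM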
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 0521985..19133f5 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -41,11 +41,23 @@
 
 
 class Block:
-    def __init__(self, w, h, d):
+    def __init__(self, w=0, h=0, d=0):
         self.width = w
         self.height = h
         self.depth = d
 
+    def elements(self):
+        return self.width * self.height * self.depth
+
+    def elements_wh(self):
+        return self.width * self.height
+
+    def clone(self):
+        return Block(self.width, self.height, self.depth)
+
+    def as_list(self):
+        return [self.height, self.width, self.depth]
+
     def __eq__(self, other):
         if self.width == other.width and self.height == other.height and self.depth == other.depth:
             return True
@@ -55,6 +67,9 @@
     def __repr__(self):
         return "<Block: {0},{1},{2}>".format(self.width, self.height, self.depth)
 
+    def to_hwc(self):
+        return [self.height, self.width, self.depth]
+
     @classmethod
     def from_string(cls, s):
         w, h, c = (int(v) for v in s.split("x"))
@@ -67,6 +82,24 @@
         # Note: index from end, as len(shp) may be > 3
         return Block(shp[-2], shp[-3], shp[-1])
 
+    @classmethod
+    def min(cls, a, b):
+        return cls(min(a.width, b.width), min(a.height, b.height), min(a.depth, b.depth))
+
+    @classmethod
+    def max(cls, a, b):
+        return cls(max(a.width, b.width), max(a.height, b.height), max(a.depth, b.depth))
+
+    @classmethod
+    def round(cls, a, b):
+        return cls(round_up(a.width, b.width), round_up(a.height, b.height), round_up(a.depth, b.depth))
+
+    @classmethod
+    def div_round_up(cls, a, b):
+        return cls(
+            round_up_divide(a.width, b.width), round_up_divide(a.height, b.height), round_up_divide(a.depth, b.depth)
+        )
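+
+    # Illustrative examples of the helpers above (hypothetical values):
+    #   Block.min(Block(8, 16, 32), Block(16, 8, 64))          -> <Block: 8,8,32>
+    #   Block.round(Block(17, 9, 40), Block(8, 8, 16))         -> <Block: 24,16,48>
+    #   Block.div_round_up(Block(17, 9, 40), Block(8, 8, 16))  -> <Block: 3,2,3>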
+
 
 class Rect:
     def __init__(self, x, y, z, x2, y2, z2):
@@ -155,6 +188,11 @@
     Axi1 = enum.auto()
 
 
+SHRAMConfig = namedtuple(
+    "SHRAMConfig", ["reserved_output_banks", "bank_size_bytes", "total_banks", "reserved_end_banks"]
+)
+
+
 class ArchitectureFeatures:
     """This class is a container for various parameters of the Ethos-U core
     and system configuration that can be tuned, either by command line
@@ -202,11 +240,9 @@
         accelerator_config,
         system_config,
         memory_mode,
-        override_block_config,
-        block_config_limit,
         max_blockdep,
-        weight_estimation_scaling,
         verbose_config,
+        arena_cache_size,
     ):
         accelerator_config = accelerator_config.lower()
         if accelerator_config not in Accelerator.member_list():
@@ -215,6 +251,26 @@
         accel_config = ArchitectureFeatures.accelerator_configs[self.accelerator_config]
         self.config = accel_config
 
+        self.accumulator_granules = {
+            SHRAMElements.Acc16: accel_config.shram_granules[SHRAMElements.Acc16],
+            SHRAMElements.Acc32: accel_config.shram_granules[SHRAMElements.Acc32],
+            SHRAMElements.Acc40: accel_config.shram_granules[SHRAMElements.Acc40],
+        }
+
+        self.ifm_bank_granules = {
+            8: accel_config.shram_granules[SHRAMElements.IFM8],
+            16: accel_config.shram_granules[SHRAMElements.IFM16],
+            32: accel_config.shram_granules[SHRAMElements.IFM32],
+        }
+
+        self.ifm_ew_bank_granules = {
+            8: accel_config.shram_granules[SHRAMElements.IFM8_Elementwise],
+            16: accel_config.shram_granules[SHRAMElements.IFM16_Elementwise],
+            32: accel_config.shram_granules[SHRAMElements.IFM32],
+        }
+
+        self.shram = SHRAMConfig(2, 1024, accel_config.shram_banks, 2 if accel_config.shram_banks > 16 else 0)
+
         self.system_config = system_config
         self.memory_mode = memory_mode
         self.is_ethos_u65_system = self.accelerator_config in (Accelerator.Ethos_U65_256, Accelerator.Ethos_U65_512)
@@ -226,11 +282,8 @@
         self.ofm_ublock = accel_config.ofm_ublock
         self.ifm_ublock = accel_config.ifm_ublock
         self.ofm_block_max = Block(64, 32, 128)
-        self.override_block_config = override_block_config
-        self.block_config_limit = block_config_limit
 
         self.max_blockdep = max_blockdep
-        self.weight_estimation_scaling = weight_estimation_scaling
 
         dpu_min_height = accel_config.ofm_ublock.height
         dpu_min_width = accel_config.ofm_ublock.width
@@ -243,7 +296,7 @@
         self.max_address_offset = 1 << 48 if self.is_ethos_u65_system else 1 << 32
 
         # Get system configuration and memory mode
-        self._get_vela_config(vela_config_files, verbose_config)
+        self._get_vela_config(vela_config_files, verbose_config, arena_cache_size)
 
         self.axi_port_width = 128 if self.is_ethos_u65_system else 64
         self.memory_bandwidths_per_cycle = self.axi_port_width * self.memory_clock_scales / 8
@@ -341,7 +394,7 @@
         # IFM/OFM block size.
         ifm_block_max = self.get_ifm_block_size(32, self.ofm_block_max, Kernel(8, 8))
         self.block_config_map = dict()
-        self.generate_block_config_map(Block(ifm_block_max.width, ifm_block_max.height, 128))
+        self.generate_block_config_map(Block(ifm_block_max.width * 2, ifm_block_max.height, 128))
 
         # Setup supported operators and restriction checkers class
         self.supported_operators = SupportedOperators()
@@ -457,7 +510,7 @@
     def mem_type_size(self, mem_type: MemType) -> int:
         """Returns size in bytes available for the given memory type"""
         if mem_type == MemType.Scratch_fast and self.is_spilling_enabled():
-            return self.sram_size
+            return self.arena_cache_size
         # Size is unknown, return max possible address offset
         return self.max_address_offset
 
@@ -505,7 +558,7 @@
             self.const_mem_area = MemPort.Axi1
             self.arena_mem_area = MemPort.Axi1
             self.cache_mem_area = MemPort.Axi0
-            self.cache_sram_size = 384 * 1024
+            self.arena_cache_size = 384 * 1024
         else:
             # Default Ethos-U55 memory mode
             # Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software
@@ -513,8 +566,9 @@
             self.const_mem_area = MemPort.Axi1
             self.arena_mem_area = MemPort.Axi0
             self.cache_mem_area = MemPort.Axi0
+            self.arena_cache_size = self.max_address_offset
 
-    def _get_vela_config(self, vela_config_files, verbose_config):
+    def _get_vela_config(self, vela_config_files, verbose_config, arena_cache_size_from_cli):
         """
         Gets the system configuration and memory modes from one or more Vela configuration file(s) or uses some
         defaults.
@@ -530,7 +584,8 @@
         self.const_mem_area = MemPort(1)
         self.arena_mem_area = MemPort(1)
         self.cache_mem_area = MemPort(1)
-        self.cache_sram_size = 1
+        self.arena_cache_size = self.max_address_offset
+        arena_cache_size_loc_text = "Default"
 
         # read configuration file(s)
         self.vela_config = None
@@ -596,12 +651,12 @@
             self.cache_mem_area = MemPort[
                 self._read_config(mem_mode_section, "cache_mem_area", self.cache_mem_area.name)
             ]
-            self.cache_sram_size = int(self._read_config(mem_mode_section, "cache_sram_size", self.cache_sram_size))
-            if self.cache_sram_size > self.max_address_offset:
-                raise ConfigOptionError(
-                    "cache_sram_size",
-                    f"{self.cache_sram_size}. Size is out of bounds, maximum is: {self.max_address_offset}",
-                )
+            found = []
+            self.arena_cache_size = int(
+                self._read_config(mem_mode_section, "arena_cache_size", self.arena_cache_size, found)
+            )
+            if found[-1]:
+                arena_cache_size_loc_text = "Configuration file"
 
         elif self.memory_mode == ArchitectureFeatures.DEFAULT_CONFIG:
             self._set_default_mem_mode()
@@ -631,6 +686,11 @@
                 self.memory_burst_length[MemArea.OnChipFlash] = self.memory_burst_length[MemArea.Sram]
                 self.memory_latency[MemArea.OnChipFlash] = self.memory_latency[MemArea.Sram]
 
+        # override sram usage
+        if arena_cache_size_from_cli is not None:
+            self.arena_cache_size = arena_cache_size_from_cli
+            arena_cache_size_loc_text = "CLI option"
+
         # check configuration
         if self._mem_port_mapping(self.const_mem_area) not in (
             MemArea.Dram,
@@ -649,13 +709,19 @@
         if self._mem_port_mapping(self.cache_mem_area) != MemArea.Sram:
             raise ConfigOptionError("cache_mem_area", self._mem_port_mapping(self.cache_mem_area).name, "Sram")
 
+        if self.arena_cache_size < 0:
+            raise ConfigOptionError("arena_cache_size", self.arena_cache_size, ">= 0")
+        if self.arena_cache_size > self.max_address_offset:
+            raise ConfigOptionError(
+                "arena_cache_size",
+                f"{self.arena_cache_size}. Size is out of bounds, maximum is: {self.max_address_offset}",
+            )
+
         # assign existing memory areas
         self.permanent_storage_mem_area = self._mem_port_mapping(self.const_mem_area)
         self.feature_map_storage_mem_area = self._mem_port_mapping(self.arena_mem_area)
         self.fast_storage_mem_area = self._mem_port_mapping(self.cache_mem_area)
 
-        self.sram_size = self.cache_sram_size if self.is_spilling_enabled() else 9999 * 1024 * 1024
-
         # display the system configuration and memory mode
         if verbose_config:
             print(f"System Configuration ({self.system_config}):")
@@ -672,24 +738,28 @@
             print(f"   const_mem_area = {self.const_mem_area.name}")
             print(f"   arena_mem_area = {self.arena_mem_area.name}")
             print(f"   cache_mem_area = {self.cache_mem_area.name}")
-            print(f"   cache_sram_size = {self.cache_sram_size}")
+            print(f"   arena_cache_size = {self.arena_cache_size} from {arena_cache_size_loc_text}")
 
             print("Architecture Settings:")
             print(f"   permanent_storage_mem_area = {self.permanent_storage_mem_area.name}")
             print(f"   feature_map_storage_mem_area = {self.feature_map_storage_mem_area.name}")
             print(f"   fast_storage_mem_area = {self.fast_storage_mem_area.name}")
-            print(f"   sram_size = {self.sram_size}")
 
-    def _read_config(self, section, key, current_value):
+    def _read_config(self, section, key, current_value, found=None):
         """
         Reads a given key from a particular section in the Vela config file. If the section contains the 'inherit'
         option then we recurse into the section specified. If inherited sections result in multiple keys for a
-        particular option then the key from the parent section is used, regardless of the parsing order
+        particular option then the key from the parent section is used, regardless of the parsing order. If specified,
+        'found' should be an empty list to which this function appends True or False, indicating whether or not the
+        key was found.
         """
         if not self.vela_config.has_section(section):
             raise ConfigOptionError("section", f"{section}. The section was not found in the Vela config file(s)")
 
-        result = str(current_value)
+        result = str(current_value) if current_value is not None else None
+        if found is not None:
+            found.append(False)
+
         if self.vela_config.has_option(section, "inherit"):
             inheritance_section = self.vela_config.get(section, "inherit")
             # check for recursion loop
@@ -697,10 +767,12 @@
                 raise ConfigOptionError(
                     "inherit", f"{inheritance_section}. This references its own section and recursion is not allowed",
                 )
-            result = self._read_config(inheritance_section, key, result)
+            result = self._read_config(inheritance_section, key, result, found)
 
         if self.vela_config.has_option(section, key):
             result = self.vela_config.get(section, key)
+            if found is not None:
+                found.append(True)
 
         return result
 
@@ -717,10 +789,8 @@
             accelerator_config=accelerator.value,
             system_config=ArchitectureFeatures.DEFAULT_CONFIG,
             memory_mode=ArchitectureFeatures.DEFAULT_CONFIG,
-            override_block_config=None,
-            block_config_limit=None,
             max_blockdep=ArchitectureFeatures.MAX_BLOCKDEP,
-            weight_estimation_scaling=1.0,
             verbose_config=False,
+            arena_cache_size=None,
         )
     return default_arch_cache[accelerator]
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
new file mode 100644
index 0000000..e4fa67e
--- /dev/null
+++ b/ethosu/vela/cascade_builder.py
@@ -0,0 +1,260 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Groups Operators in a schedule together to form Cascades.
+from .numeric_util import round_up
+from .operation import NpuBlockType
+from .shape4d import Shape4D
+
+non_cascadable_blocks = (
+    NpuBlockType.Default,
+    NpuBlockType.VectorProduct,
+    NpuBlockType.ElementWise,
+    NpuBlockType.ReduceSum,
+)
+
+
+class CascadeInfo:
+    """Contains metadata about a cascade"""
+
+    def __init__(self, start, end, buffers, mem_usage: int):
+        self.start = start
+        self.end = end
+        self.buffers = buffers
+        self.mem_usage = mem_usage
+
+
+class BufferMap:
+    """Caches the buffers seen"""
+
+    def __init__(self):
+        self.buffer_map = {}
+
+    def get_buffer(self, producer, consumer, cost):
+        assert producer or consumer
+        key = (producer, consumer)
+        if key not in self.buffer_map:
+            # No cached buffer between these two SchedulerOperations
+            if consumer is None:
+                # There are either no consumers or multiple consumers - FeatureMap needs to be stored in full
+                buffer_shape = producer.ofm.shape
+                buffer_size = producer.ofm_size_in_bytes()
+            elif producer is None:
+                # First Op in subgraph or cascade - FeatureMap needs to be stored in full
+                buffer_shape = consumer.ifm.shape
+                buffer_size = consumer.ifm_size_in_bytes()
+            elif producer.requires_full_ofm or consumer.requires_full_ifm:
+                # FeatureMap needs to be stored in full
+                buffer_shape = max(producer.ofm.shape, consumer.ifm.shape)
+                buffer_size = max(producer.ofm_size_in_bytes(), consumer.ifm_size_in_bytes())
+            else:
+                # Use a rolling buffer
+                buffer_shape = rolling_buffer_shape(cost[producer].stripe, cost[consumer].stripe_input)
+                buffer_size = buffer_shape.elements() * producer.ofm.dtype.size_in_bytes()
+
+            self.buffer_map[key] = (buffer_shape, buffer_size)
+
+        return self.buffer_map[key]
+
+
+def rolling_buffer_shape(producer_stripe: Shape4D, consumer_stripe_input: Shape4D) -> Shape4D:
+    """Calculates the storage shape of the rolling buffer between two SchedulerOperations in a Cascade"""
+    buffer_height = round_up(producer_stripe.height + consumer_stripe_input.height, consumer_stripe_input.height)
+    # Rolling buffers have to conform to NHCWB16 format
+    return consumer_stripe_input.with_height(buffer_height).with_depth(round_up(producer_stripe.depth, 16))
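+
+# Illustrative example (hypothetical stripes): a producer stripe of height 8 feeding a
+# consumer stripe input of height 10 gives a rolling buffer height of
+# round_up(8 + 10, 10) = 20, with the buffer depth rounded up to a multiple of 16 for
+# NHCWB16.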
+
+
+class CascadeBuilder:
+    """Class for grouping SchedulerOperations into cascades"""
+
+    def __init__(self, sched_ops, spilling, non_local_mem_usage=None):
+        self.sched_ops = sched_ops
+        self.no_cascade = 0
+        self.non_local_mem_usage = non_local_mem_usage if non_local_mem_usage else {}
+        self.spilling = spilling
+
+    def _is_cascadable(self, sched_op, cost) -> bool:
+        """Checks if 'sched_op' can be cascaded"""
+        return (
+            sched_op.op_type.npu_block_type not in non_cascadable_blocks
+            and cost.stripe.height < sched_op.ofm.shape.height
+        )
+
+    def _estimate_sram_usage(self, sched_op, cost) -> int:
+        """Estimate the SRAM required for the Op if all FeatureMaps are in SRAM"""
+        ifm2_size = sched_op.ifm2_size_in_bytes()
+        if sched_op.requires_full_ifm:
+            ifm_size = sched_op.ifm_size_in_bytes()
+        else:
+            ifm_size = (
+                cost.stripe_input.with_depth(round_up(cost.stripe_input.depth, 16)).elements()
+                * sched_op.ifm.dtype.size_in_bytes()
+            )
+        if sched_op.requires_full_ofm:
+            ofm_size = sched_op.ofm_size_in_bytes()
+        else:
+            ofm_size = (
+                cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes()
+            )
+
+        return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)
+
+    def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
+        ref_cost = ref_schedule.cost_map
+        fallback_cost = fallback_schedule.cost_map
+        cost = {}
+        cascade_map = {}
+        buffers = BufferMap()
+
+        # Peak memory usage so far - updated continuously, unless Dedicated SRAM is used, where this is a hard limit
+        peak_sram_usage = guiding_mem_limit
+
+        idx = 0
+        while idx < len(self.sched_ops):
+            op = self.sched_ops[idx]
+            if op in cost:
+                # Already processed this Op
+                idx += 1
+                continue
+
+            if not self._is_cascadable(op, ref_cost[op]):
+                # Op is not a candidate for cascading - assign fallback cost
+                cost[op] = fallback_cost[op]
+                if not self.spilling:
+                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
+                idx += 1
+                continue
+
+            # Propose a cascade starting with this Op
+            cascade_start = op.index
+            # Keep track of which Ops are in the proposed cascade as well as the best cascade so far
+            ops_in_cascade = [op]
+            ops_in_best_cascade = [op]
+            # Get the size of the weight buffer
+            weight_buffer = 0
+            if ref_cost[op].buffered_weight_tensor:
+                weight_buffer = ref_cost[op].buffered_weight_tensor.storage_size()
+
+            # The first IFM needs to be stored in full
+            cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0
+
+            # Add non-local memory usage
+            cascade_ifm_size += self.non_local_mem_usage.get(op, 0)
+
+            # Sum of all intermediate cascade buffers (including weight buffers)
+            cascade_buffers = weight_buffer
+            # Best cascade size - Initially it's the fallback cost of the first Op in the cascade
+            best_cascade_size = self._estimate_sram_usage(op, fallback_cost[op])
+
+            # Op is the producer of the OFM consumed by the next Op to consider
+            producer = op
+            while True:
+                dependants = producer.get_dependants()
+                if len(dependants) != 1:
+                    # producer is either the last Op in the schedule or the start of a branch
+                    break
+
+                current_op = dependants[0]
+                if (
+                    current_op in cost
+                    or current_op not in ref_cost
+                    or not self._is_cascadable(current_op, ref_cost[current_op])
+                    or producer.ofm.shape != current_op.ifm.shape
+                ):
+                    # Current op has already been processed or cannot be cascaded
+                    break
+
+                # Get the size of the FeatureMap buffers between current and neighbouring Ops
+                op_full_ifm = current_op.ifm_size_in_bytes()
+                op_full_ofm = current_op.ofm_size_in_bytes()
+                _, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)
+
+                # Get the size of the weight buffer
+                op_weight_buffer = 0
+                if ref_cost[current_op].buffered_weight_tensor:
+                    op_weight_buffer = ref_cost[current_op].buffered_weight_tensor.storage_size()
+
+                # Calculate the uncascaded memory requirement for current Op
+                uncascaded_sram_usage = op_full_ifm + op_full_ofm + self.non_local_mem_usage.get(current_op, 0)
+
+                # Add current Op to cascade
+                ops_in_cascade.append(current_op)
+
+                # Increase the accumulated intermediate buffers in the cascade
+                cascade_buffers += op_ifm_buffer + op_weight_buffer
+
+                if self.spilling:
+                    # For Dedicated SRAM only the intermediate buffers are in SRAM
+                    if uncascaded_sram_usage < peak_sram_usage or cascade_buffers > peak_sram_usage:
+                        # Cascade until an Op fits in its entirety or the accumulated buffers no longer fit
+                        break
+                    else:
+                        # Any addition to the cascade that fits is the new best cascade for Dedicated SRAM
+                        ops_in_best_cascade = [op for op in ops_in_cascade]
+                        best_cascade_size = cascade_buffers
+
+                else:
+                    # Calculate the total size of the current cascade
+                    cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm
+
+                    # Determine if cascading search should stop
+                    if (
+                        uncascaded_sram_usage < peak_sram_usage
+                        and best_cascade_size < peak_sram_usage
+                        or (cascade_ifm_size + cascade_buffers) > best_cascade_size
+                    ):
+                        # Both the existing cascade and the current Op fit
+                        break
+
+                    # Determine if current cascade is the best so far
+                    if cascade_size < best_cascade_size:
+                        best_cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm
+                        ops_in_best_cascade = [op for op in ops_in_cascade]
+
+                producer = current_op
+
+            if len(ops_in_best_cascade) > 1:
+                # A cascade was created - assign cascade and ref_cost to all of the Ops
+                cascade_end = cascade_start + (len(ops_in_best_cascade) - 1)
+                buffers_in_cascade = {}
+                prev_op = None
+                for cascaded_op in ops_in_best_cascade:
+                    cost[cascaded_op] = ref_cost[cascaded_op]
+                    cost[cascaded_op].cascade = cascade_end
+                    if prev_op:
+                        rolling_buffer_shape, _ = buffers.get_buffer(prev_op, cascaded_op, ref_cost)
+                        buffers_in_cascade[cascaded_op] = rolling_buffer_shape
+
+                    prev_op = cascaded_op
+
+                # Create a CascadeInfo for the cascade
+                cascade_map[cascade_end] = CascadeInfo(
+                    cascade_start, cascade_end, buffers_in_cascade, best_cascade_size
+                )
+                if not self.spilling:
+                    # Update peak memory usage
+                    peak_sram_usage = max(best_cascade_size, peak_sram_usage)
+            else:
+                # Assign fallback cost to the initial Op
+                cost[op] = fallback_cost[op]
+                if not self.spilling:
+                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
+
+        # Update costing and cascade information for the ref_schedule
+        ref_schedule.cost_map = cost
+        ref_schedule.cascades = cascade_map
+        return ref_schedule
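+
+
+# A minimal usage sketch (illustrative; the schedules and memory limit below are
+# hypothetical and are normally supplied by the scheduler):
+#
+#   builder = CascadeBuilder(sched_ops, arch.is_spilling_enabled(), non_local_mem_usage)
+#   builder.build_cascades(striped_schedule, fallback_schedule, staging_limit_bytes)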
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index e2c71ac..a9e3839 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -21,7 +21,6 @@
 from . import graph_optimiser
 from . import high_level_command_stream_generator
 from . import high_level_command_to_npu_op
-from . import insert_dma
 from . import live_range
 from . import lut
 from . import mark_tensors
@@ -30,14 +29,14 @@
 from . import pass_packing
 from . import scheduler
 from . import tensor_allocation
-from . import weight_compressor
 from .debug_database import DebugDatabase
-from .errors import VelaError
 from .nn_graph import PassPlacement
 from .nn_graph import TensorAllocator
 from .operation import Op
 from .rewrite_graph import verify_graph_health
 from .rewrite_graph import visit_graph_post_order
+from .scheduler import OptimizationStrategy
+from .tensor import MemArea
 from .tensor import MemType
 from .tensor import Tensor
 
@@ -135,6 +134,18 @@
         DebugDatabase.add_source(op)
 
 
+def _check_schedule(nng, arch, scheduler_options):
+    # check sram usage for optimisation strategy
+    sram_usage = nng.get_root_subgraph().memory_used.get(MemArea.Sram)
+    if sram_usage is not None and scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
+        if sram_usage > scheduler_options.optimization_sram_limit:
+            print(
+                f"Warning: SRAM target for arena memory area exceeded."
+                f" Target = {scheduler_options.optimization_sram_limit} Bytes,"
+                f" Actual = {sram_usage} Bytes"
+            )
+
+
 def compiler_driver(nng, arch, options, scheduler_options):
     assert verify_graph_health(nng)
 
@@ -150,8 +161,6 @@
 
     nng = mark_tensors.mark_tensor_purpose(nng, arch, options.verbose_tensor_purpose)
     assert verify_graph_health(nng)
-    nng = insert_dma.insert_dma_commands(nng, arch, options.verbose_graph)
-    assert verify_graph_health(nng)
     pass_packing.pack_into_passes(nng, arch, options.verbose_packing)
     assert verify_graph_health(nng)
 
@@ -162,20 +171,14 @@
         start = time.time()
 
     # Run the scheduler
-    scheduler.schedule_passes(nng, arch, scheduler_options)
+    scheduler.schedule_passes(nng, arch, options, scheduler_options)
+    _check_schedule(nng, arch, scheduler_options)
 
     if options.timing:
         stop = time.time()
         print("Scheduling took %f s" % (stop - start))
         start = time.time()
 
-    # Update the compressed weights now that we have determined the
-    # block config, and calc and pack the scales and biases
-    weight_compressor.update_pass_weight_and_scale_tensors(nng, arch)
-
-    if scheduler_options.cache_bias_scale_tensor:
-        scheduler.move_scales_to_fast_storage(nng, arch)
-
     # LiveRanges for constant tensors for all Npu subgraphs
     permanent_storage = arch.permanent_storage_mem_area
     lr_graph_flash = live_range.LiveRangeGraph()
@@ -188,12 +191,8 @@
     # Calculate live ranges for all constant Npu tensors, in permanent storage
     for sg in nng.subgraphs:
         if sg.placement == PassPlacement.Npu:
-            lr_graph_flash = live_range.extract_live_ranges_from_cascaded_passes(
-                sg,
-                permanent_storage,
-                MemType.Permanent_NPU,
-                ignore_subgraph_input_output_tensors=True,
-                lr_graph=lr_graph_flash,
+            lr_graph_flash = live_range.create_linear_live_range_graph(
+                sg, permanent_storage, MemType.Permanent_NPU, lr_graph=lr_graph_flash,
             )
 
     if len(nng.subgraphs) > 1:
@@ -212,88 +211,21 @@
             lr_graph=lr_graph_flash,
         )
 
-    # Allocate all non-constant tensors to the root, i.e. Cpu, subgraph. This step
-    # will start at the root subgraph's input and traverse from top to bottom. When
-    # it comes across an Npu-op it will extract live ranges for it's corresponding
-    # Npu subgraph and add them to the root's live range graph.
-    # The non-constant tensors are stored either in arch.feature_map_storage_mem_area or
-    # arch.fast_storage_mem_area.
-    # When these memory areas are the same, all non-constant tensors are allocated together.
-    # Otherwise they are allocated separately.
-
     root_sg = nng.get_root_subgraph()
 
-    alloc_list = []
-    if arch.is_spilling_enabled():
-        mem_alloc_scratch_fast = (arch.fast_storage_mem_area, set((MemType.Scratch_fast,)))
-        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch,)))
-        # Order is important
-        alloc_list.append(mem_alloc_scratch_fast)
-        alloc_list.append(mem_alloc_scratch)
-    else:
-        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))
-        alloc_list.append(mem_alloc_scratch)
-
-    for mem_area, mem_type_set in alloc_list:
-        if arch.is_spilling_enabled() and mem_area == arch.fast_storage_mem_area:
-            # For the case where scratch_fast != scratch: attempt to place feature maps used between
-            # cascaded passes in fast storage. Bisection is used to find the max possible usage of SRAM.
-            alloc_results = []
-            while True:
-                assert len(alloc_results) < 10, "Infinite allocator loop"
-                sram_factor, dry_test = next_sram_factor(alloc_results)
-                if sram_factor is None:
-                    break
-                # Try to move as many feature maps as possible to SRAM before allocating
-                sram_limit = sram_factor * arch.sram_size
-                for sg in nng.subgraphs:
-                    scheduler.use_fast_storage_for_feature_maps(sg, sram_limit, arch)
-                alloc_success = tensor_allocation.allocate_tensors(
-                    nng,
-                    root_sg,
-                    arch,
-                    mem_area,
-                    mem_type_set,
-                    max_size=arch.sram_size,
-                    dry_test=dry_test,
-                    tensor_allocator=options.tensor_allocator,
-                    verbose_allocation=options.verbose_allocation,
-                    cpu_tensor_alignment=options.cpu_tensor_alignment,
-                )
-                if dry_test or not alloc_success:
-                    for sg in nng.subgraphs:
-                        scheduler.undo_use_fast_storage(sg, arch)
-                alloc_results.append(alloc_success)
-            if not alloc_results[-1]:
-                raise VelaError(
-                    f"Sram limit {arch.sram_size} bytes, has been exceeded by the scratch fast tensor. "
-                    "Increasing the value of --weight-estimation-scaling may help to resolve the issue. "
-                    "See OPTIONS.md for more information"
-                )
-        else:
-            tensor_allocation.allocate_tensors(
-                nng,
-                root_sg,
-                arch,
-                mem_area,
-                mem_type_set,
-                tensor_allocator=options.tensor_allocator,
-                verbose_allocation=options.verbose_allocation,
-                cpu_tensor_alignment=options.cpu_tensor_alignment,
-            )
-
     # Generate command streams and serialise Npu-ops into tensors
     for sg in nng.subgraphs:
-        high_level_command_stream_generator.generate_high_level_command_stream(
-            nng, sg, arch, options.verbose_high_level_command_stream
-        )
-        lut.optimize_high_level_cmd_stream(sg, arch)
-        high_level_command_to_npu_op.generate_register_command_stream_for_sg(
-            nng, sg, arch, options.verbose_register_command_stream
-        )
-        scratch_tens, scratch_fast_tens, flash_tens = npu_serialisation.serialise_npu_subgraph_into_tensors(
-            sg, arch, scratch_tens, scratch_fast_tens, flash_tens
-        )
+        if sg.placement == PassPlacement.Npu:
+            high_level_command_stream_generator.generate_high_level_command_stream_for_schedule(
+                nng, sg, arch, options.verbose_high_level_command_stream
+            )
+            lut.optimize_high_level_cmd_stream(sg, arch)
+            high_level_command_to_npu_op.generate_register_command_stream_for_sg(
+                nng, sg, arch, options.verbose_register_command_stream
+            )
+            scratch_tens, scratch_fast_tens, flash_tens = npu_serialisation.serialise_npu_subgraph_into_tensors(
+                sg, arch, scratch_tens, scratch_fast_tens, flash_tens
+            )
 
     npu_serialisation.rewrite_npu_call_ops(root_sg, arch)
 
@@ -316,4 +248,4 @@
         cpu_tensor_alignment=options.cpu_tensor_alignment,
     )
 
-    npu_performance.calc_performance_for_network(nng, arch)
+    npu_performance.calc_new_performance_for_network(nng, arch)
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py
index 918ca0a..9553c80 100644
--- a/ethosu/vela/errors.py
+++ b/ethosu/vela/errors.py
@@ -21,7 +21,7 @@
     """Base class for vela exceptions"""
 
     def __init__(self, data):
-        self.data = f"Error! {data}"
+        self.data = f"Error: {data}"
         self.error_msg = data
 
     def __str__(self):
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index f4472f9..d2598ae 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -141,6 +141,7 @@
         new_op = Operation(Op.SplitSliceRead, split_op.name)
         new_op.inputs = [inp]
         ofm_shape_idx = 0
+        read_shape = offset_end
 
         # For Split the offset cannot be extracted from the tensor so it has to
         # be calculated from the index of the output tensor
@@ -160,11 +161,13 @@
 
                 if out == tens:
                     ofm_shape_idx = idx
+                    read_shape = split_op.ofm_shapes[idx]
                     break
 
                 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
 
         new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
+        new_op.read_shapes[0] = read_shape
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
         new_op.ifm_shapes.append(Shape4D(inp.shape))
@@ -189,10 +192,12 @@
             cons_op = op.ofm.consumer_list[0]
             if cons_op.ifm == op.ofm:
                 cons_op.read_offsets[0] = op.read_offsets[0]
+                cons_op.read_shapes[0] = op.read_shapes[0]
                 cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[0])
                 cons_op.ifm_shapes[0] = op.ifm_shapes[0]
             elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm:
                 cons_op.read_offsets[1] = op.read_offsets[0]
+                cons_op.read_shapes[1] = op.read_shapes[0]
                 cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[1])
                 cons_op.ifm_shapes[1] = op.ifm_shapes[0]
 
@@ -212,6 +217,7 @@
             avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
             avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
             avgpool_op.read_offsets[0] = op.read_offsets[0]
+            avgpool_op.read_shapes[0] = op.read_shapes[0]
 
             op.ifm.consumer_list.remove(op)
             DebugDatabase.add_optimised(op, avgpool_op)
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index 19a363c..d353b48 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -41,6 +41,7 @@
         npu_block_type: NpuBlockType,
         concat_offsets: List[int],
         split_offset: Shape4D = None,
+        split_shape: Shape4D = None,
         k_height: int = 1,
         upscaling_factor: int = 1,
     ):
@@ -55,12 +56,14 @@
                 new_start_coord[idx] += split_offset[idx]
                 new_end_coord[idx] += split_offset[idx]
 
-        if (split_offset is None) and (
-            npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
-        ):
+        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
             # these types of operations do a "dot product" or sum over the entire IFM
-            new_start_coord[-1] = 0
-            new_end_coord[-1] = ifm_shape.depth
+            if split_offset is None:
+                new_start_coord[-1] = 0
+                new_end_coord[-1] = ifm_shape.depth
+            else:
+                new_start_coord[-1] = split_offset[-1]
+                new_end_coord[-1] = new_start_coord[-1] + split_shape[-1]
 
         if len(new_end_coord) >= 1:
             new_end_coord[-1] = min(new_end_coord[-1], ifm_shape.depth)
@@ -126,6 +129,14 @@
 
         return Box(start, end)
 
+    def is_subbox_of(self, other):
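+        # Returns True if this Box lies entirely within 'other' (inclusive bounds);
+        # note that an empty Box implicitly yields None rather than False.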
+        if self.start_coord and self.end_coord:
+            assert len(self.start_coord) == len(other.start_coord)
+            assert len(self.end_coord) == len(other.end_coord)
+            return all(a >= b for (a, b) in zip(self.start_coord, other.start_coord)) and all(
+                a <= b for (a, b) in zip(self.end_coord, other.end_coord)
+            )
+
     def get_size_shape(self):
         return [int(self.end_coord[i] - self.start_coord[i]) for i in range(len(self.end_coord))]
 
@@ -142,9 +153,6 @@
 
 
 class Command:
-    def get_ofm_y_range_for_pass(self, ps_requested):
-        return None
-
     def is_npu_pass_command(self):
         return False
 
@@ -158,8 +166,6 @@
         self,
         ps,
         block_config,
-        is_first,
-        is_last,
         is_first_h_stripe,
         is_last_h_stripe,
         ifm_tensor,
@@ -168,7 +174,6 @@
         ofm_box,
         weight_tensor=None,
         weight_box=None,
-        scale_tensor=None,
         ifm2_tensor=None,
         ifm2_box=None,
         pad_top=0,
@@ -176,8 +181,6 @@
     ):
         self.ps = ps
         self.block_config = block_config
-        self.is_first = is_first
-        self.is_last = is_last
         self.is_first_h_stripe = is_first_h_stripe
         self.is_last_h_stripe = is_last_h_stripe
         self.ifm_tensor = ifm_tensor
@@ -187,7 +190,6 @@
         self.ofm_tensor = ofm_tensor
         self.ofm_box = ofm_box
         self.weight_tensor = weight_tensor
-        self.scale_tensor = scale_tensor
         self.weight_box = weight_box
         self.pad_top = pad_top
         self.pad_bottom = pad_bottom
@@ -209,13 +211,6 @@
 
     __repr__ = __str__
 
-    def get_ofm_y_range_for_pass(self, ps_requested):
-        if ps_requested != self.ps:
-            return None
-        if len(self.ofm_box.start_coord) >= 3:
-            return (self.ofm_box.start_coord[-3], self.ofm_box.end_coord[-3])
-        return None
-
     def get_block_dimensions(self):
         ofm_box = self.ofm_box
         block_config = self.block_config
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index c01790a..ecd375e 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -14,15 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Description:
-# Generate a high-level command stream from a scheduled subgraph with CascadedPasses.
-#
-# Also used during scheduling to work out allowable IFM/OFM overlap, this functionality can be accessed using
-# calc_allowed_ofm_ifm_overlap_for_cascaded_pass().
+# Generate a high-level command stream from a schedule
 from .high_level_command_stream import Box
 from .high_level_command_stream import DMA
 from .high_level_command_stream import NpuStripe
-from .nn_graph import PassPlacement
-from .nn_graph import SchedulingStrategy
 from .numeric_util import round_up_divide
 from .operation import create_activation_function
 from .operation import NpuBlockType
@@ -32,326 +27,192 @@
 
 
 def dma_if_necessary(ps, box, tensor):
-    if tensor.needs_dma():
-        dma_op = tensor.ops[0]
-        in_tensor = dma_op.inputs[0]
-        yield DMA(ps, in_tensor, tensor, box)
+    src_tensor = tensor.src_tensor
+    if src_tensor and tensor.mem_area != src_tensor.mem_area:
+        yield DMA(ps, src_tensor, tensor, box)
 
 
-def generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx):
-    is_first = idx == 0
-    is_last = idx == len(passes) - 1
-    ps = passes[idx]
-    block_config = block_configs[idx]
-    npu_block_type = ps.npu_block_type
-    split_offsets = list(ps.primary_op.read_offsets)  # offset for [ifm, ifm2]
-
-    if (
-        len(ps.inputs) == 2
-        and ps.ifm_tensor is not None
-        and ps.ifm2_tensor is not None
-        and npu_block_type == NpuBlockType.ElementWise
-    ):
-        # Ensure correct ifm and ifm2 order
-        if ps.inputs[0] == ps.primary_op.inputs[1] and ps.inputs[1] == ps.primary_op.inputs[0]:
-            ps.ifm_tensor, ps.ifm2_tensor = ps.ifm2_tensor, ps.ifm_tensor
-            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
-
-    ifm_tensor = ps.ifm_tensor
-    ifm_shape = None
-    if ifm_tensor.shape != []:
-        ifm_shape = ps.ifm_shapes[0]
-    ifm2_tensor = ps.ifm2_tensor
-    ifm2_shape = None
-    if ifm2_tensor is not None and ifm2_tensor.shape != []:
-        ifm2_shape = ps.ifm_shapes[1]
-    ofm_tensor = ps.ofm_tensor
-    ofm_shape = ps.ofm_shapes[0]
-    weight_tensor = ps.weight_tensor
-    scale_tensor = ps.scale_tensor
-
-    ofm_start = [0, 0, 0, 0]
-    ofm_end = ofm_shape.as_list()
-
-    strides = None
-    skirt = None
-    upscaling = 1
-    if ps.primary_op is not None:
-        strides = ps.primary_op.attrs.get("strides", None)
-        skirt = ps.primary_op.attrs.get("skirt", None)
-        if ps.primary_op.type == Op.Conv2DBackpropInputSwitchedBias:
-            upscaling = ofm_shape.height // ifm_shape.height
-        elif ps.primary_op.type == Op.ResizeBilinear:
-            upscaling = round_up_divide(ofm_shape.height, ifm_shape.height)
-
-    concat_offset = [0, 0, 0, 0]
-
-    for op in ps.ops:
-        if op.write_offset is not None:
-            concat_offset = op.write_offset.as_list()
-            ofm_start = concat_offset[:]
-            ofm_end = (op.write_offset + op.write_shape).as_list()
-        if op.type.is_relu_op() or op.type in (Op.Tanh, Op.Sigmoid):
-            ps.primary_op.activation = create_activation_function(op.type)
-
-    if strat == SchedulingStrategy.WeightStream:
-        ofm_step = block_config[-1]
-        ofm_stop = ofm_end[-1]
-        if weight_tensor is None or not weight_tensor.needs_dma():
-            ofm_step = ofm_stop
-        for start in range(ofm_start[-1], ofm_stop, ofm_step):
-            end = min(start + ofm_step, ofm_stop)
-            ofm_start[-1] = start
-            ofm_end[-1] = end
-            ofm_box = Box(ofm_start, ofm_end)
-            ifm_box = None
-            ifm2_box = None
-
-            if ifm_shape is not None:
-                ifm_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-                    strides, skirt, ifm_shape, npu_block_type, concat_offset, split_offsets[0], upscaling,
-                )
-            else:
-                ifm_box = Box([], [])
-            if ifm2_shape is not None:
-                ifm2_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-                    strides, skirt, ifm2_shape, npu_block_type, concat_offset, split_offsets[1], upscaling,
-                )
-            else:
-                ifm2_box = Box([], [])
-
-            for intermediate in ps.intermediates:
-                if (
-                    intermediate is not None
-                    and intermediate.shape != []
-                    and intermediate.purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT)
-                ):
-                    if intermediate.purpose is TensorPurpose.FeatureMap:
-                        intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-                            strides,
-                            skirt,
-                            Shape4D(intermediate.shape),
-                            npu_block_type,
-                            concat_offset,
-                            split_offsets[0],
-                            upscaling,
-                        )
-                    else:
-                        intermediate_box = Box([0] * len(intermediate.shape), list(intermediate.shape))
-                    yield from dma_if_necessary(ps, intermediate_box, intermediate)
-
-            weight_box = None
-            if weight_tensor is not None:
-                weight_offset = concat_offset[len(weight_tensor.shape) - 1]
-                weight_oc_start = start - weight_offset
-                weight_oc_end = end - weight_offset
-
-                weight_box = Box.make_weight_box(
-                    weight_tensor.shape,
-                    npu_block_type,
-                    weight_oc_start,
-                    weight_oc_end,
-                    weight_tensor.weight_transpose_depthwise,
-                )
-                yield from dma_if_necessary(ps, weight_box, weight_tensor)
-
-            yield NpuStripe(
-                ps,
-                block_config,
-                is_first,
-                is_last,
-                True,
-                True,
-                ifm_tensor,
-                ifm_box,
-                ofm_tensor,
-                ofm_box,
-                weight_tensor,
-                weight_box,
-                scale_tensor,
-                ifm2_tensor=ifm2_tensor,
-                ifm2_box=ifm2_box,
-            )
-
-    elif strat == SchedulingStrategy.IfmStream:
-        assert ifm_shape is not None
-        y_step = block_config[0]
-        y_start = ofm_start[-3]
-        y_dim = ofm_end[-3]
-
-        if idx > 0:
-            ifm_y_present = 0
-            prev_pass = passes[idx - 1]
-            prev_pass_gen = generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx - 1)
-        else:
-            ifm_y_present = 1
-            ifm_y_present = ifm_shape.height
-            prev_pass_gen = []
-            prev_pass = None
-
-        if len(passes) == 1:
-            # no cascading, can just issue one big stripe
-            # but only if we've done allocation and OFM does not overlap IFM
-            if ifm_tensor.address is not None and ofm_tensor.address is not None:
-                if (
-                    ifm_tensor.address + ifm_tensor.storage_size() <= ofm_tensor.address
-                    or ofm_tensor.address + ofm_tensor.storage_size() <= ifm_tensor.address
-                ):
-                    y_step = y_dim
-
-        weight_box = None
-        scale_box = None
-
-        for start in range(y_start, y_dim, y_step):
-            end = min(start + y_step, y_dim)
-            ofm_start[-3] = start
-            ofm_end[-3] = end
-            ofm_box = Box(ofm_start, ofm_end)
-
-            k_height = 1
-            if npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
-                if ps.primary_op is not None:
-                    k_height = ps.primary_op.attrs["ksize"][1]
-            else:
-                if weight_tensor is not None:
-                    k_height = weight_tensor.shape[0]
-
-            ifm_box, pad_top, pad_bottom = ofm_box.transform_with_strides_and_skirt(
-                strides, skirt, ifm_shape, npu_block_type, concat_offset, split_offsets[0], k_height, upscaling,
-            )
-
-            ifm_y_needed = 1
-            if len(ifm_box.end_coord) >= 3:
-                ifm_y_needed = ifm_box.end_coord[-3]
-            if ifm_y_present < ifm_y_needed:
-                for prev_cmd in prev_pass_gen:
-                    yield prev_cmd
-                    rng = prev_cmd.get_ofm_y_range_for_pass(prev_pass)
-                    if rng is not None:
-                        ifm_y_present = max(ifm_y_present, rng[1])
-                        if ifm_y_present >= ifm_y_needed:
-                            break
-
-            for intermediate in ps.intermediates:
-                if (
-                    intermediate is not None
-                    and intermediate.shape != []
-                    and intermediate.purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT)
-                ):
-                    if intermediate.purpose is TensorPurpose.FeatureMap:
-                        intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-                            strides,
-                            skirt,
-                            Shape4D(intermediate.shape),
-                            npu_block_type,
-                            concat_offset,
-                            split_offsets[0],
-                            upscaling,
-                        )
-                    else:
-                        intermediate_box = Box([0] * len(intermediate.shape), list(intermediate.shape))
-                    yield from dma_if_necessary(ps, intermediate_box, intermediate)
-
-            if scale_tensor is not None and scale_tensor.purpose == TensorPurpose.FSBias and scale_box is None:
-                scale_box = Box([0] * len(scale_tensor.shape), list(scale_tensor.shape))
-                yield from dma_if_necessary(ps, scale_box, scale_tensor)
-
-            if weight_tensor is not None and weight_box is None:
-                weight_box = Box.make_weight_box(
-                    weight_tensor.shape, npu_block_type, weights_transposed=weight_tensor.weight_transpose_depthwise
-                )
-                yield from dma_if_necessary(ps, weight_box, weight_tensor)
-
-            # Check if first/last stripe in pass
-            is_first_h_stripe = start == y_start
-            is_last_h_stripe = (start + y_step) >= y_dim
-
-            stripe = NpuStripe(
-                ps,
-                block_config,
-                is_first,
-                is_last,
-                is_first_h_stripe,
-                is_last_h_stripe,
-                ifm_tensor,
-                ifm_box,
-                ofm_tensor,
-                ofm_box,
-                weight_tensor,
-                weight_box,
-                scale_tensor,
-                None,
-                None,
-                pad_top,
-                pad_bottom,
-            )
-            yield stripe
-    else:
-        assert 0, "unknown scheduling strategy"
-
-
-def generate_high_level_command_stream_for_pass_list(strat, passes, block_configs):
-    if strat == SchedulingStrategy.WeightStream:
-        for idx in range(len(passes)):
-            yield from generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx)
-    elif strat == SchedulingStrategy.IfmStream:
-        yield from generate_high_level_command_stream_for_pass(strat, passes, block_configs, len(passes) - 1)
-    else:
-        assert 0, "Unknown streaming strategy"
-
-
-def generate_high_level_command_stream_for_cascaded_pass(cps):
-    yield from generate_high_level_command_stream_for_pass_list(
-        cps.strategy, cps.passes, [ps.block_config for ps in cps.passes]
-    )
-
-
-def generate_high_level_command_stream(nng, sg, arch, verbose_high_level_command_stream):
+def generate_high_level_command_stream_for_schedule(nng, sg, arch, verbose_high_level_command_stream):
     res = []
-    for cps in sg.cascaded_passes:
-        if cps.placement == PassPlacement.Npu:
-            res += list(generate_high_level_command_stream_for_cascaded_pass(cps))
+    # sg.sched_ops are ordered by execution
+    processed_cascades = set()
+    for sched_op in sg.sched_ops:
+        op_info = sg.schedule.cost_map[sched_op]
+        if op_info.cascade in processed_cascades:
+            # This cascade has already been processed
+            continue
+
+        if op_info.cascade == 0:
+            # Generate high-level commands for this Op in isolation
+            res += list(generate_high_level_commands_for_sched_op(sched_op, sg.schedule))
+        else:
+            # Generate high-level commands for the whole cascade
+            cascade_info = sg.schedule.cascades[op_info.cascade]
+            # Start from the last Op in the cascade
+            res += list(generate_high_level_commands_for_sched_op(sg.sched_ops[cascade_info.end], sg.schedule))
+            processed_cascades.add(op_info.cascade)
 
     sg.high_level_command_stream = res
     if verbose_high_level_command_stream:
         sg.print_high_level_command_stream()
 
 
-def calc_allowed_ofm_ifm_overlap_for_pass_list(strat, passes, block_configs):
-    highest_ofm_write = 0
-    if not passes[0].ifm_tensor or not passes[-1].ofm_tensor:
-        return 0
+def generate_high_level_commands_for_sched_op(sched_op, schedule):
+    op_info = schedule.cost_map[sched_op]
+    cascade_info = schedule.cascades.get(op_info.cascade)
+    npu_block_type = sched_op.parent_ps.npu_block_type
+    block_config = op_info.block_config
+    ps = sched_op.parent_ps
+    parent_op = sched_op.parent_op
+    ofm_tensor = ps.ofm_tensor
 
-    ifm_read = passes[0].ifm_tensor.storage_size()
-    min_overlap = 999999999999999999999
-    ofm_size = passes[-1].ofm_tensor.storage_size()
-    if strat == SchedulingStrategy.WeightStream:
-        return 0
-    for cmd in generate_high_level_command_stream_for_pass_list(strat, passes, block_configs):
-        if cmd.is_npu_pass_command():
-            if cmd.is_first:
-                ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
-                    cmd.ifm_box.start_coord, cmd.ps.ifm_shapes[0], is_top_box=False
+    # Get Tensors and Full Shapes
+    (ifm_tensor, ifm2_tensor, uncomp_weight_tensor, _, _,) = parent_op.get_ifm_ifm2_weights_biases_ofm()
+    ifm = sched_op.ifm
+    ifm2 = sched_op.ifm2
+    ofm_shape = sched_op.ofm.shape
+
+    # Get Kernel strides and upscaling factor
+    kernel_stride = sched_op.kernel.stride
+    strides = [1, kernel_stride.y, kernel_stride.x, 1]
+    skirt = parent_op.attrs.get("skirt", None)
+    upscaling = 1
+    if sched_op.op_type == Op.Conv2DBackpropInputSwitchedBias:
+        upscaling = ofm_shape.height // ifm.shape.height
+    elif sched_op.op_type == Op.ResizeBilinear:
+        upscaling = round_up_divide(ofm_shape.height, ifm.shape.height)
+
+    # Get Kernel height
+    k_height = 1
+    if npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
+        if parent_op is not None:
+            k_height = parent_op.attrs["ksize"][1]
+    else:
+        if uncomp_weight_tensor is not None:
+            k_height = uncomp_weight_tensor.shape[0]
+
+    # Define Start and End coordinates for the OFM
+    ofm_start = Shape4D(0, 0, 0, op_info.ofm_depth_slices[0])
+    ofm_end = ofm_shape
+
+    ofm_depth_slices = op_info.ofm_depth_slices
+
+    # Read/Write offsets
+    read_offsets = list(parent_op.read_offsets)  # offset for [ifm, ifm2]
+    read_shapes = list(parent_op.read_shapes)  # read shapes for [ifm, ifm2]
+    write_offset = Shape4D(0, 0, 0, 0)
+    if parent_op.write_offset is not None:
+        write_offset = parent_op.write_offset
+        ofm_start = write_offset
+        ofm_end = parent_op.write_offset + parent_op.write_shape
+
+    # Create activation function if needed
+    for op in ps.ops:
+        if op.type.is_relu_op() or op.type in (Op.Tanh, Op.Sigmoid):
+            ps.primary_op.activation = create_activation_function(op.type)
+
+    # Generate commands for the Op that produces this Op's IFM, if applicable
+    if cascade_info is None or cascade_info.start == sched_op.index:
+        # Lone Op or First Op in cascade - all IFM data is present
+        ifm_present = Box([0, 0, 0, 0], ifm.shape.as_list())
+        producer_op = None
+        prev_cmd_gen = []
+    else:
+        ifm_present = Box([0, 0, 0, 0], [0, 0, 0, 0])
+        producer_op = sched_op.ifm.connection.producers[0]
+        prev_cmd_gen = generate_high_level_commands_for_sched_op(producer_op, schedule)
+
+    ofm_step = op_info.stripe
+    for start_height in range(ofm_start.height, ofm_end.height, ofm_step.height):
+        end_height = min(start_height + ofm_step.height, ofm_end.height)
+        for start_width in range(ofm_start.width, ofm_end.width, ofm_step.width):
+            end_width = min(start_width + ofm_step.width, ofm_end.width)
+
+            for depth_idx, start_channel in enumerate(ofm_depth_slices[:-1]):
+                start_channel = max(start_channel, ofm_start.depth)
+                end_channel = min(ofm_depth_slices[depth_idx + 1], ofm_end.depth)
+
+                # Construct the OFM box for the current stripe
+                ofm_box_start = Shape4D(ofm_start.batch, start_height, start_width, start_channel)
+                ofm_box_end = Shape4D(ofm_end.batch, end_height, end_width, end_channel)
+                ofm_box = Box(ofm_box_start.as_list(), ofm_box_end.as_list())
+                ifm_box = Box([], [])
+                ifm2_box = Box([], [])
+
+                # Calculate IFM input box based on the OFM box
+                if ifm:
+                    ifm_box, pad_top, pad_bottom = ofm_box.transform_with_strides_and_skirt(
+                        strides,
+                        skirt,
+                        ifm.shape,
+                        npu_block_type,
+                        write_offset.as_list(),
+                        read_offsets[0],
+                        read_shapes[0],
+                        k_height,
+                        upscaling,
+                    )
+
+                # Calculate IFM2 input box based on the OFM box
+                if ifm2:
+                    ifm2_box, pad_top, pad_bottom = ofm_box.transform_with_strides_and_skirt(
+                        strides,
+                        skirt,
+                        ifm2.shape,
+                        npu_block_type,
+                        write_offset.as_list(),
+                        read_offsets[1],
+                        read_shapes[1],
+                        k_height,
+                        upscaling,
+                    )
+
+                ifm_required = ifm_box
+                # Get the Op that produces this Op's IFM data - only applicable within cascades
+                if producer_op:
+                    assert op_info.cascade != 0
+                    assert op_info.cascade == schedule.cost_map[producer_op].cascade
+                    for prev_cmd in prev_cmd_gen:
+                        yield prev_cmd
+                        if prev_cmd.is_npu_pass_command() and prev_cmd.ps == producer_op.parent_ps:
+                            ifm_present.end_coord = prev_cmd.ofm_box.end_coord
+                            if ifm_required.is_subbox_of(ifm_present):
+                                # There is enough IFM data - exit loop
+                                break
+
+                # Information about the current stripe's location in the cascade
+                is_first_h_stripe = ofm_box_start.height == ofm_start.height
+                is_last_h_stripe = ofm_box_end.height >= ofm_end.height
+
+                # Calculate the weight box - i.e. the subshape of weights needed for this NpuStripe command
+                weight_tensor = op_info.npu_weights_tensor
+                if op_info.npu_weights_tensor:
+                    weight_box = Box([0, 0, 0, start_channel], [1, 1, 1, end_channel])
+
+                    if op_info.buffered_weight_tensor and is_first_h_stripe:
+                        yield from dma_if_necessary(sched_op.parent_ps, weight_box, op_info.buffered_weight_tensor)
+                        weight_tensor = op_info.buffered_weight_tensor
+                else:
+                    weight_box = None
+
+                if parent_op.activation_lut:
+                    lut_tensor = [tens for tens in parent_op.inputs if tens.purpose == TensorPurpose.LUT][0]
+                    lut_box = Box([0] * len(lut_tensor.shape), list(lut_tensor.shape))
+                    yield from dma_if_necessary(sched_op.parent_ps, lut_box, lut_tensor)
+
+                yield NpuStripe(
+                    sched_op.parent_ps,
+                    block_config.old_style_representation(),
+                    is_first_h_stripe,
+                    is_last_h_stripe,
+                    ifm_tensor,
+                    ifm_box,
+                    ofm_tensor,
+                    ofm_box,
+                    weight_tensor,
+                    weight_box,
+                    ifm2_tensor=ifm2_tensor,
+                    ifm2_box=ifm2_box,
+                    pad_top=pad_top,
+                    pad_bottom=pad_bottom,
                 )
-                if ifm_read is None:
-                    return 0
-            if cmd.is_last:
-                write_offset = cmd.ofm_tensor.address_offset_for_coordinate(
-                    cmd.ofm_box.end_coord, cmd.ps.ofm_shapes[0], is_top_box=True
-                )
-                if write_offset is None:
-                    return 0
-                highest_ofm_write = max(write_offset, highest_ofm_write)
-
-            if cmd.is_first or cmd.is_last:
-                overlap_required = max(highest_ofm_write - min(ifm_read, ofm_size), 0)
-                can_overwrite = ofm_size - overlap_required
-                min_overlap = min(min_overlap, can_overwrite)
-
-            if cmd.is_first:
-                ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
-                    cmd.ifm_box.end_coord, cmd.ps.ifm_shapes[0], is_top_box=True
-                )
-
-    min_overlap = max(min_overlap, 0)
-    return min_overlap
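
Note: the rewritten generator above walks sg.sched_ops in execution order, emits whole cascades once (starting from the last Op in the cascade), and cuts each OFM into height, width and depth-slice stripes. A rough sketch of that stripe iteration under simplified assumptions (stripe_boxes, plain shape tuples and a depth-slice boundary list are illustrative, not the real Shape4D/Box API):

# Rough sketch of how an OFM is split into stripe boxes, assuming a plain
# (batch, height, width, depth) tuple and a list of depth slice boundaries.
def stripe_boxes(ofm_shape, stripe, depth_slices):
    n, h, w, c = ofm_shape
    for y0 in range(0, h, stripe[1]):
        y1 = min(y0 + stripe[1], h)
        for x0 in range(0, w, stripe[2]):
            x1 = min(x0 + stripe[2], w)
            for i, c0 in enumerate(depth_slices[:-1]):
                c1 = min(depth_slices[i + 1], c)
                yield ([0, y0, x0, c0], [n, y1, x1, c1])

# Example: a 1x8x8x32 OFM, 4-row stripes, depth handled in two 16-channel slices.
boxes = list(stripe_boxes((1, 8, 8, 32), (1, 4, 8, 32), [0, 16, 32]))
assert len(boxes) == 4  # 2 height stripes x 1 width stripe x 2 depth slices
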
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index ad9e266..4ef7bee 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -51,6 +51,7 @@
 from .high_level_command_stream import Command
 from .high_level_command_stream import DMA
 from .high_level_command_stream import NpuStripe
+from .numeric_util import round_up
 from .operation import NpuBlockType
 from .operation import Op
 from .operation import Operation
@@ -61,9 +62,10 @@
 from .shape4d import Shape4D
 from .tensor import MemType
 from .tensor import Tensor
-from .tensor import TensorBlockTraversal
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
+from .tensor import TensorSubPurpose
+from .weight_compressor import WeightKey
 
 
 class BasePointerIndex(IntEnum):
@@ -81,12 +83,6 @@
 }
 
 
-block_traversal_map = {
-    TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST,
-    TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST,
-}
-
-
 # Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
 elementwise_op_map = {
     Op.Mul: NpuElementWiseOp.MUL,
@@ -272,44 +268,46 @@
 
 
 def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
-    """Returns address ranges for weights"""
+    """Returns address ranges for weights and scales"""
     weights = []
-    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
-    weight_substream_offsets = weight_tensor.compressed_values_substream_offsets[stream_index]
-    substreams = len(weight_substream_offsets) - 1  # Offset list must terminate with full stream length
-
-    # Extract weight substream offsets and calculate their lengths
-    assert len(weight_substream_offsets) > 1 and (weight_substream_offsets[0] == 0)
-    weight_addr = weight_tensor.address_for_coordinate(weight_box.start_coord)
-    region = get_region(weight_tensor.mem_type, arch)
-    for core in range(substreams):
-        address = weight_addr + weight_substream_offsets[core]
-        length = weight_substream_offsets[core + 1] - weight_substream_offsets[core]
-        addr_range = NpuAddressRange(region, int(address), int(length))
-        weights.append(addr_range)
-    return weights
-
-
-def create_biases(
-    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
-) -> List[NpuAddressRange]:
-    """Returns address ranges for biases"""
     biases = []
-    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
-    scale_substream_offsets = scale_tensor.compressed_values_substream_offsets[stream_index]
-    substreams = len(scale_substream_offsets) - 1  # Offset list must terminate with full stream length
+    region = get_region(weight_tensor.mem_type, arch)
 
-    # Extract scale substream offsets and calculate their lengths
-    assert len(scale_substream_offsets) > 1 and (scale_substream_offsets[0] == 0)
-    scale_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])
+    w_tensor_src = weight_tensor
+    if weight_tensor.src_tensor:
+        w_tensor_src = weight_tensor.src_tensor
 
-    region = get_region(scale_tensor.mem_type, arch)
-    for core in range(substreams):
-        address = scale_addr + scale_substream_offsets[core]
-        length = scale_substream_offsets[core + 1] - scale_substream_offsets[core]
-        addr_range = NpuAddressRange(region, int(address), int(length))
-        biases.append(addr_range)
-    return biases
+    core_offset = 0
+    for core in range(0, arch.ncores):
+        # Get weight range per core
+        key = WeightKey(core, weight_box.start_coord[-1])
+        if key in w_tensor_src.encoded_ranges:
+            weight_range = w_tensor_src.encoded_ranges[key]
+            if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
+                assert weight_tensor != w_tensor_src
+                # Double buffered inside weight_tensor
+                address = weight_tensor.address + w_tensor_src.max_range_bytes * ((weight_range.index - core) % 2)
+                address += core_offset
+                core_offset += round_up(weight_range.total_bytes, 16)
+            else:
+                if weight_tensor == w_tensor_src:
+                    # Straight from source tensor
+                    address = weight_tensor.address + weight_range.offset
+                else:
+                    # Single buffered inside weight tensor
+                    address = weight_tensor.address + core_offset
+                    core_offset += round_up(weight_range.total_bytes, 16)
+
+            # Location of weights in tensor
+            addr_range = NpuAddressRange(
+                region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
+            )
+            weights.append(addr_range)
+            # Location of biases in tensor
+            addr_range = NpuAddressRange(region, int(address), round_up(int(weight_range.scale_bytes), 16))
+            biases.append(addr_range)
+
+    return weights, biases
 
 
 def create_npu_activation(op: Operation) -> NpuActivation:
@@ -353,9 +351,7 @@
     npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)
 
     if cmd.weight_tensor is not None:
-        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
-        if cmd.scale_tensor is not None:
-            npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)
+        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
     npu_op.activation = create_npu_activation(op)
     npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
     npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
@@ -375,7 +371,10 @@
     if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
         npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
     else:
-        npu_op.block_traversal = block_traversal_map[cmd.weight_tensor.block_traversal]
+        if cmd.weight_tensor.src_tensor:
+            npu_op.block_traversal = cmd.weight_tensor.src_tensor.hw_traversal
+        else:
+            npu_op.block_traversal = cmd.weight_tensor.hw_traversal
     return npu_op
 
 
@@ -464,17 +463,29 @@
     else:
         dest_region = get_region(cmd.out_tensor.mem_type, arch)
 
-    start_coord = cmd.box.start_coord
-    src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
-    dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)
+    if cmd.in_tensor.purpose == TensorPurpose.Weights:
+        # Get weight range per core
+        sz = 0
+        for core in range(0, arch.ncores):
+            key = WeightKey(core, cmd.box.start_coord[-1])
+            if key in cmd.in_tensor.encoded_ranges:
+                weight_range = cmd.in_tensor.encoded_ranges[key]
+                sz += round_up(weight_range.total_bytes, 16)
 
-    if cmd.in_tensor.compressed_values is not None:
-        if cmd.out_tensor.purpose == TensorPurpose.FSBias:
-            sz = cmd.in_tensor.storage_size()
-        else:
-            stream_index = cmd.in_tensor.compressed_stream_index_from_coord(start_coord)
-            sz = cmd.in_tensor.size_of_compressed_stream(stream_index)
+                if core == 0:
+                    weight_range = cmd.in_tensor.encoded_ranges[key]
+                    src_addr = cmd.in_tensor.address + weight_range.offset
+
+                    if cmd.out_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
+                        dest_addr = cmd.out_tensor.address + cmd.in_tensor.max_range_bytes * (
+                            (weight_range.index - core) % 2
+                        )
+                    else:
+                        dest_addr = cmd.out_tensor.address
     else:
+        start_coord = cmd.box.start_coord
+        src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
+        dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)
         sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
     src = NpuAddressRange(src_region, int(src_addr), int(sz))
     dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
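
Note: create_weights above resolves one encoded weight range per core and, for double-buffered weights, alternates between the two halves of the buffer using ((range index - core) % 2), packing cores at 16-byte aligned offsets. A toy sketch of that address arithmetic (weight_addresses, ranges_per_core and buffer_base are hypothetical names, not the real tensor classes):

# Toy sketch of the double-buffered weight addressing, assuming each core has an
# encoded range described by (range_index, total_bytes) and 16-byte packing.
def round_up(value, alignment):
    return ((value + alignment - 1) // alignment) * alignment

def weight_addresses(buffer_base, max_range_bytes, ranges_per_core):
    # ranges_per_core: list of (range_index, total_bytes), one entry per core
    addresses = []
    core_offset = 0
    for core, (range_index, total_bytes) in enumerate(ranges_per_core):
        # Alternate between the two halves of the double buffer per depth slice
        half = (range_index - core) % 2
        addresses.append(buffer_base + max_range_bytes * half + core_offset)
        core_offset += round_up(total_bytes, 16)
    return addresses

# Example: two cores fetching the second depth slice of a double-buffered tensor
print(weight_addresses(0x1000, 2048, [(1, 100), (2, 100)]))
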
diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py
deleted file mode 100644
index bbe18f7..0000000
--- a/ethosu/vela/insert_dma.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
-#
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the License); you may
-# not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an AS IS BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Description:
-# Insert DMA operations into the graph for transfering weights.
-from . import rewrite_graph
-from .operation import NpuBlockType
-from .operation import Op
-from .operation import Operation
-from .tensor import MemArea
-from .tensor import MemType
-from .tensor import TensorPurpose
-from .weight_compressor import compress_weights
-
-
-def weights_fit_sram(arch, op, tens, nng):
-    # Compresses weights and checks if they fit in SRAM
-    if tens.purpose != TensorPurpose.Weights:
-        return True
-
-    min_weight_size = 0
-    if len(tens.shape) == 4:
-        min_weight_size = tens.shape[0] * tens.shape[1] * tens.shape[2] * arch.OFMSplitDepth
-    elif len(tens.shape) == 2:
-        min_weight_size = tens.shape[0] * arch.OFMSplitDepth
-
-    compress_weights(arch, nng, tens, op.type.npu_block_type, 16, 16, op.get_dilation_h_w())
-
-    # Need to be fit into Sram, as a double buffer
-    worst_buffer_size = tens.compression_scale_for_worst_weight_stream * min_weight_size * 2
-    if worst_buffer_size > arch.sram_size:
-        print(
-            "Weights, {}, are too big to be DMAed to SRAM, estimated minimum size is {} bytes".format(
-                tens.name, worst_buffer_size
-            )
-        )
-        return False
-    return True
-
-
-def insert_dma_cmd(op, arch, nng):
-    if op.type == Op.DMA or not op.run_on_npu:
-        return op
-
-    is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in op.inputs)
-    max_ifm_shram_avail = (
-        (arch.available_shram_banks(is_lut_used) - arch.shram_reserved_output_banks) * arch.shram_bank_size // 2
-    )
-
-    for idx, tens in enumerate(op.inputs):
-
-        if tens.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
-            # Tensor is in permanent storage
-            # Only when permanent storage differs from fast storage, there is a point moving the data
-            if (
-                tens.mem_area in (MemArea.Dram, MemArea.OffChipFlash)
-                and arch.permanent_storage_mem_area != arch.fast_storage_mem_area
-            ) or tens.purpose == TensorPurpose.LUT:
-                if tens.purpose in (TensorPurpose.Weights, TensorPurpose.LUT) or (
-                    tens.purpose == TensorPurpose.FeatureMap
-                    and op.type.is_binary_elementwise_op()
-                    and tens.shape != []
-                    and op.ifm_shapes[0] != op.ofm_shapes[0]
-                    and tens.storage_size() > max_ifm_shram_avail
-                ):
-                    only_vector_product_consumers = True
-                    for oper in tens.consumers():
-                        if oper is None or oper.type.npu_block_type != NpuBlockType.VectorProduct:
-                            only_vector_product_consumers = False
-                            break
-
-                    # Tensor products has no need for DMA, tensors are only read once and can be in flash.
-                    # Other operations re-reads tensors, this is better done from SRAM.
-                    # LUTs must be placed in the last 2 blocks of SHRAM.
-                    if (
-                        not only_vector_product_consumers and weights_fit_sram(arch, op, tens, nng)
-                    ) or tens.purpose == TensorPurpose.LUT:
-                        # Insert a DMA command here, as well as a new tensor situated in SRAM of the same size.
-                        new_tens = tens.clone_into_fast_storage(arch)
-                        dma_cmd = Operation(Op.DMA, tens.ops[0].name + "_dma")
-                        dma_cmd.inputs = [tens]
-                        dma_cmd.set_output_tensor(new_tens)
-                        dma_cmd.attrs["source"] = tens.mem_area
-                        dma_cmd.attrs["destination"] = new_tens.mem_area
-                        dma_cmd.run_on_npu = True
-                        if tens.purpose == TensorPurpose.LUT:
-                            new_tens.mem_area = MemArea.Shram
-                        op.inputs[idx] = new_tens
-    return op
-
-
-def insert_dma_commands(nng, arch, verbose_graph=False):
-
-    for idx, sg in enumerate(nng.subgraphs):
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [insert_dma_cmd])
-    if verbose_graph:
-        nng.print_graph("After DMA insertion")
-    return nng
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index de001e5..d75a167 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -18,10 +18,14 @@
 # Can work with either a pass packed subgraph or a scheduled subgraph.
 from typing import List
 
+import numpy as np
+
 from .nn_graph import PassPlacement
 from .operation import Op
+from .tensor import MemArea
 from .tensor import MemType
 from .tensor import Tensor
+from .tensor import TensorPurpose
 
 
 class LiveRange:
@@ -32,6 +36,7 @@
         self.size = 0
         self.name = ""
         self.alignment = alignment
+        self.mem_area = tens.mem_area if tens else MemArea.Unknown
 
         if tens:
             self.add_tensor(tens)
@@ -52,15 +57,19 @@
 
         self.tensors.append(tens)
 
-    def mark_usage(self, op_time):
-        if op_time == -1:
+    def mark_usage(self, op_time, op_length=1):
+        op_time_start = max(op_time, 0)
+        op_time_end = op_time + op_length
+        if op_time_end <= op_time_start:
             return
-        op_time_start = op_time
-        op_time_end = op_time + 1
 
         self.start_time = min(self.start_time, op_time_start)
         self.end_time = max(self.end_time, op_time_end)
 
+    def set_buffer_size(self, buffer_size):
+        self.size = buffer_size
+        self.mem_area = MemArea.Sram
+
     def overlaps_ranges(self, other):
         return max(self.start_time, other.start_time) < min(self.end_time, other.end_time)
 
@@ -106,6 +115,7 @@
         self.ignore_tensors = set()
         self.processed_subgraphs = set()
         self.current_time = 0
+        self.end_time = None
 
     def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
         # Return the live range of the tensor (or any of its clones)
@@ -127,6 +137,23 @@
         self.ranges[out_tens] = live_range
         return live_range
 
+    def update_endtime(self):
+        self.end_time = 0
+        for rng in self.ranges.values():
+            self.end_time = max(self.end_time, rng.end_time)
+        return self.end_time + 1
+
+    def get_temporal_memory_usage(self, target_mem_area):
+        if not self.end_time:
+            self.update_endtime()
+        usage = np.zeros(self.end_time, dtype=np.int32)
+        for rng in self.ranges.values():
+            if rng.mem_area == target_mem_area:
+                # End time is inclusive
+                usage[rng.start_time : rng.end_time + 1] += rng.size
+
+        return usage
+
 
 def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
     if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
@@ -279,9 +306,7 @@
             # is called. Go into said subgraph and extract live ranges before continuing.
             # Use default allocation alignment of 16 for Npu tensors
             npu_sg = cps_primary_op.attrs["subgraph"]
-            lr_graph = extract_live_ranges_from_cascaded_passes(
-                npu_sg, target_mem_area, target_mem_type_set, False, lr_graph,
-            )
+            lr_graph = _extract_live_ranges_from_schedule(npu_sg, target_mem_area, target_mem_type_set, lr_graph)
             # Set the new time after handling the Npu subgraph
             time_for_pass = lr_graph.current_time
             cps.time = time_for_pass
@@ -308,3 +333,89 @@
     # Add subgraph to set of processed subgraphs
     lr_graph.processed_subgraphs.add(sg)
     return lr_graph
+
+
+def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_graph):
+    assert lr_graph is not None
+    sg_time = lr_graph.current_time
+    for ps in sg.passes:
+        for tens in ps.inputs + ps.outputs + ps.intermediates:
+            if tens.purpose == TensorPurpose.Weights or tensor_should_be_ignored(
+                lr_graph, tens, target_mem_area, target_mem_type_set
+            ):
+                continue
+
+            rng = lr_graph.get_or_create_range(tens)
+            rng.mark_usage(sg_time)
+
+    for sched_op, op_info in sg.schedule.cost_map.items():
+        if op_info.npu_weights_tensor and not (
+            tensor_should_be_ignored(lr_graph, op_info.npu_weights_tensor, target_mem_area, target_mem_type_set)
+        ):
+            rng = lr_graph.get_or_create_range(op_info.npu_weights_tensor)
+            rng.mark_usage(sg_time)
+
+    lr_graph.current_time += 1
+    return lr_graph
+
+
+def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
+    time_for_cascade = {}
+    for sched_op in sg.sched_ops:
+        op_info = sg.schedule.cost_map[sched_op]
+        cascade = op_info.cascade
+        cascade_info = sg.schedule.cascades.get(cascade, None)
+
+        time_to_set = time_for_cascade.get(cascade, lr_graph.current_time)
+
+        op_info.time_index = time_to_set
+
+        # Mark usage for all tensors related to this Pass
+        ps = sched_op.parent_ps
+        for tens in ps.inputs + ps.outputs + ps.intermediates:
+            if (
+                target_mem_area == MemArea.Sram
+                and cascade_info
+                and tens == ps.ifm_tensor
+                and sched_op in cascade_info.buffers
+            ):
+                # This tensor is a rolling buffer in a cascade and the size of the LiveRange needs to be modified
+                # to enable temporal memory snapshots without modifying the original Tensor
+                rng = lr_graph.get_or_create_range(tens)
+                rng.set_buffer_size(cascade_info.buffers[sched_op].elements() * sched_op.ifm.dtype.size_in_bytes())
+            elif (
+                tens.purpose == TensorPurpose.Weights
+                or tens.purpose == TensorPurpose.FSBias
+                or tens.mem_type not in target_mem_type_set
+                or tens.mem_area != target_mem_area
+            ):
+                continue
+
+            else:
+                rng = lr_graph.get_or_create_range(tens)
+
+            rng.mark_usage(time_to_set)
+
+        weight_tens = op_info.buffered_weight_tensor
+        if weight_tens and weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
+            rng = lr_graph.get_or_create_range(weight_tens)
+            if weight_tens.pre_buffer:
+                rng.mark_usage(time_to_set - 1, 2)
+            else:
+                rng.mark_usage(time_to_set)
+
+        if time_to_set == lr_graph.current_time:
+            lr_graph.current_time += 2
+
+        if cascade != 0:
+            time_for_cascade[cascade] = time_to_set
+
+    end_time = lr_graph.update_endtime()
+
+    for tens in sg.output_tensors:
+        if tens.mem_type not in target_mem_type_set or tens.mem_area != target_mem_area:
+            continue
+        rng = lr_graph.get_or_create_range(tens)
+        rng.mark_usage(end_time)
+
+    return lr_graph
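
Note: get_temporal_memory_usage above accumulates, per time step, the sizes of all live ranges assigned to a memory area, with an inclusive end time. A minimal standalone sketch of that histogram, assuming live ranges are reduced to (start_time, end_time, size) tuples (temporal_usage is an illustrative name, not the real API):

import numpy as np

# Minimal sketch of a temporal memory histogram with an inclusive end time.
def temporal_usage(ranges, end_time):
    usage = np.zeros(end_time, dtype=np.int32)
    for start, end, size in ranges:
        usage[start : end + 1] += size
    return usage

# Example: two overlapping buffers; the peak at time 2 is their sum.
print(temporal_usage([(0, 2, 100), (2, 4, 50)], end_time=6))
# -> [100 100 150  50  50   0]
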
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index f810df0..7dc2d72 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -69,6 +69,7 @@
         self.npu_block_type = npu_block_type
         self.block_config = None  # will be filled in by scheduler
         self.shared_buffer = None  # will be filled in by scheduler
+        self.scheduling_info = None  # will be filled in by scheduler
 
         self.predecessors = []
         self.successors = []
@@ -123,6 +124,7 @@
 
         self.predecessors = []
         self.successors = []
+        self.sram_used = 0
 
     def __str__(self):
         return "<nng.CascadedPass strategy=%s x %s '%s',  passes=%s, block_configs=%s>" % (
@@ -149,7 +151,9 @@
         self.command_stream_tensor = None
         self.flash_tensor = None
         # Scratch information locally used in the scheduler
-        self.scheduling_info = {}
+        self.schedule = None
+        self.sched_ops = []
+
         self.generated_stream_id = None
 
         self.memory_used = {}
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index c83f8f5..b1dae4e 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -19,45 +19,28 @@
 #
 # Called during scheduling to evaluate different proposals, as well as post-scheduling to provide a final performance
 # estimate.
+import copy
 from enum import auto
 from enum import IntEnum
 
 import numpy as np
 
 from . import numeric_util
+from .architecture_allocator import ArchitectureBlockConfig
 from .architecture_features import Accelerator
-from .architecture_features import Block
-from .data_type import DataType
-from .nn_graph import PassPlacement
-from .nn_graph import SchedulerRewrite
-from .operation import NpuBlockType
+from .architecture_features import NpuBlockType
+from .architecture_features import SHRAMElements
+from .architecture_features import TensorFormat
+from .numeric_util import round_up
+from .operation import Kernel
 from .operation import Op
-from .shared_buffer_allocation import is_acc_40bits_used
+from .scheduler import Schedule
+from .scheduler import SchedulerOperation
+from .shape4d import Shape4D
 from .tensor import BandwidthDirection
 from .tensor import MemArea
-from .tensor import shape_num_elements
-from .tensor import Tensor
-from .tensor import TensorBlockTraversal
-from .tensor import TensorFormat
 from .tensor import TensorPurpose
-
-
-def rolling_buffer_dims_from_passes(arch, ps1, block_config_ps1, ps2, block_config_ps2):
-    ofm_block = Block(block_config_ps2[-3], block_config_ps2[-4], block_config_ps2[-1])
-    kernel = ps2.primary_op.kernel
-
-    if ps2.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
-        op = ps2.primary_op
-        ifm_block_depth = arch.calc_ifm_block_depth(op.ifm_shapes[0].depth, op.ifm.dtype.size_in_bits())
-    else:
-        ifm_block_depth = block_config_ps2[-1]
-
-    ifm_block = arch.get_ifm_block_size(ifm_block_depth, ofm_block, kernel, arch.ofm_block_max)
-
-    # The performed height calculation is for worst case
-    height = numeric_util.round_up(ifm_block.height + block_config_ps1[0], block_config_ps1[0])
-    width = ifm_block.width
-    return [height, width]
+from .weight_compressor import WeightKey
 
 
 class PassCycles(IntEnum):
@@ -91,82 +74,173 @@
         )
 
 
-def make_bandwidth_array():
-    return np.zeros((MemArea.Size, TensorPurpose.Size, BandwidthDirection.Size))
+class PerformanceQuery:
+    def __init__(self, npu_block_type=0):
+        self.npu_block_type = npu_block_type
+        self.ifm_shape = Shape4D(0)
+        self.ifm_format = TensorFormat.NHWC
+        self.ifm_memory_area = MemArea.Unknown
+        self.ifm2_memory_area = MemArea.Unknown
+        self.ifm_bits = 0
+        self.ifm2_bits = 0
+        self.ifm2_shape = None
+        self.ifm2_format = TensorFormat.NHWC
+        self.ofm_shape = Shape4D(0)
+        self.ofm_format = TensorFormat.NHWC
+        self.ofm_memory_area = MemArea.Unknown
+        self.ofm_bits = 0
+        self.const_shape = Shape4D(0)
+        self.const_memory_area = MemArea.Unknown
+        self.kernel = Kernel(1, 1)
+        self.config = ArchitectureBlockConfig()
 
 
-def make_cycles_array():
-    return np.zeros(PassCycles.Size)
+class CycleCost:
+    def __init__(self):
+        self.op_macs = 0
+        self.op_cycles = 0
+
+    def __mul__(self, scale):
+        out = CycleCost()
+        out.op_macs = self.op_macs * scale
+        out.op_cycles = self.op_cycles * scale
+        return out
+
+    def __iadd__(self, rhs):
+        self.op_macs += rhs.op_macs
+        self.op_cycles += rhs.op_cycles
+        return self
+
+    def __str__(self):
+        return "macs = {}, cycles = {}".format(self.op_macs, self.op_cycles)
 
 
-def make_metrics_arrays():
-    return (make_bandwidth_array(), 0, make_cycles_array())
+class ElementAccess:
+    def __init__(self):
+        # List of ONLY element access counts, consumers
+        # need to scale these values by the correct bitwidths
+        # to calculate memory bandwidth
+        self.ifm_read = [0, 0]  # ifm1, ifm2
+        self.ofm_write = 0
+        self.weights_refetch = 0
+        self.const_read = [0, 0]  # weights, scales
+
+    def __mul__(self, scale):
+        out = ElementAccess()
+        out.ifm_read[0] = self.ifm_read[0] * scale
+        out.ifm_read[1] = self.ifm_read[1] * scale
+        out.ofm_write = self.ofm_write * scale
+        out.weights_refetch = self.weights_refetch * scale
+        out.const_read[0] = self.const_read[0] * scale
+        out.const_read[1] = self.const_read[1] * scale
+        return out
+
+    def __iadd__(self, rhs):
+        self.ifm_read[0] += rhs.ifm_read[0]
+        self.ifm_read[1] += rhs.ifm_read[1]
+        self.ofm_write += rhs.ofm_write
+        self.weights_refetch += rhs.weights_refetch
+        self.const_read[0] += rhs.const_read[0]
+        self.const_read[1] += rhs.const_read[1]
+        return self
+
+    def __str__(self):
+        return "ifm read = {}, ofm write = {}, const read={}".format(self.ifm_read, self.ofm_write, self.const_read)
 
 
-def get_ifm_block_depth(npu_block_type, ifm_depth, ifm_elemwidth, block_traversal, ofm_blk_depth):
-    ifm_blk_depth = ofm_blk_depth
+def _strides_for_shape(shape: Shape4D, format: TensorFormat, element_bits):
+    if format == TensorFormat.NHWC:
+        strides = [0, 0, 0, 0]
+        strides[3] = element_bits / 8  # +Z
+        strides[2] = (element_bits * shape.depth) // 8  # +X
+        strides[1] = (element_bits * shape.depth * shape.width) // 8  # +Y
+        strides[0] = (element_bits * shape.depth * shape.width * shape.height) // 8  # +N
+    elif format == TensorFormat.NHCWB16:
+        strides = [0, 0, 0, 0, 0]
+        strides[4] = element_bits / 8  # +Z
+        strides[3] = (element_bits * 16) / 8  # +X
+        strides[2] = (element_bits * 16 * shape.width) / 8  # +C
+        strides[1] = (element_bits * shape.width * shape.depth) / 8  # +Y
+        strides[0] = (element_bits * shape.width * shape.depth * shape.height) / 8  # +N
 
-    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
-        if ifm_elemwidth == 16 or block_traversal == TensorBlockTraversal.PartKernelFirst:
-            ifm_blk_depth = 16
-        elif ifm_elemwidth == 8:
-            ifm_blk_depth = 32
-        else:
-            ifm_blk_depth = 8
-
-    return min(ifm_depth, ifm_blk_depth)
+    return strides
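
Note (illustration, not part of the patch): a worked example of the NHWC stride formula above for an 8-bit 1x8x8x16 feature map.

# +Z is one element, +X one channel vector, +Y one row, +N one full feature map.
element_bits, depth, width, height = 8, 16, 8, 8
strides = [
    (element_bits * depth * width * height) // 8,  # +N = 1024 bytes
    (element_bits * depth * width) // 8,           # +Y = 128 bytes
    (element_bits * depth) // 8,                   # +X = 16 bytes
    element_bits // 8,                             # +Z = 1 byte
]
assert strides == [1024, 128, 16, 1]
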
 
 
-def get_minimal_cmd_cycles(
-    arch, ifm_tensor, ofm_tensor, ifm_blk: Block, ofm_blk: Block, output_cycles, ifm_shape4D, ofm_shape4D, dpu_cycles=0
+def _estimate_memory_transfer_efficiency(
+    arch, is_read, mem_area, format: TensorFormat, element_bits, block_size, shape4D, to_transfer
 ):
-    ifm_tens_blk = Tensor((1, ifm_blk.height, ifm_blk.width, ifm_blk.depth), ifm_tensor.dtype, "ifm_blk")
-    ofm_tens_blk = Tensor((1, ofm_blk.height, ofm_blk.width, ofm_blk.depth), ofm_tensor.dtype, "ofm_blk")
-    cycles_ifm_blk = (
-        estimate_memory_transfer_efficiency(
-            arch, ifm_tensor.mem_area, BandwidthDirection.Read, ifm_tens_blk, ifm_blk, shape4D=ifm_shape4D
-        )
-        / arch.memory_bandwidths_per_cycle[ifm_tensor.mem_area]
-    )
-    cycles_ofm_blk = (
-        estimate_memory_transfer_efficiency(
-            arch, ofm_tensor.mem_area, BandwidthDirection.Write, ofm_tens_blk, ofm_blk, shape4D=ofm_shape4D
-        )
-        / arch.memory_bandwidths_per_cycle[ofm_tensor.mem_area]
-    )
-    return (
-        arch.memory_latency[ifm_tensor.mem_area][BandwidthDirection.Read]
-        + cycles_ifm_blk
-        + dpu_cycles
-        + output_cycles
-        + arch.memory_latency[ofm_tensor.mem_area][BandwidthDirection.Write]
-        + cycles_ofm_blk
-    ) / 4
+    burst_len = 8
 
+    strides = _strides_for_shape(shape4D, format, element_bits)
 
-def estimate_output_cycles(
-    arch,
-    npu_block_type,
-    primary_op,
-    num_elems,
-    ifm_tensor,
-    ofm_tensor,
-    use_acc_40bits=False,
-    ifm2_tensor=None,
-    block_config: Block = None,
-):
-    faf = None if primary_op.activation is None else primary_op.activation.op_type
-    if npu_block_type == NpuBlockType.ElementWise and ifm_tensor.dtype == DataType.int32:
-        if ifm2_tensor is None:
-            # Unary op
-            output_perf_index = 0
+    if format == TensorFormat.NHCWB16:
+        if strides[2] == block_size.depth:  # TODO: is this check correct for non 8-bit?
+            burst_len = element_bits * block_size.depth * block_size.width
+        elif is_read:
+            burst_len = 16 * element_bits * block_size.width
         else:
-            # Binary op
-            output_perf_index = 1
-    elif primary_op.type == Op.Mul and ofm_tensor.dtype == DataType.int32:
+            burst_len = 16 * element_bits * block_size.width * arch.ncores
+    elif format == TensorFormat.NHWC:
+        if is_read:
+            if strides[3] == block_size.depth:
+                burst_len = element_bits * block_size.depth * block_size.width
+            else:
+                burst_len = element_bits * block_size.depth
+        else:
+            if block_size.depth <= 16 and strides[3] == block_size.depth:
+                burst_len = element_bits * block_size.depth * block_size.width
+            else:
+                burst_len = min(64 * 8, 16 * element_bits * arch.ncores, block_size.depth * element_bits)
+
+    burst_len = burst_len // 8  # bits->bytes
+    burst_len = min(arch.memory_burst_length[mem_area], burst_len)
+    return to_transfer * (arch.memory_burst_length[mem_area] / burst_len)
+
+
+def _estimate_minimum_memory_cycles(arch, query: PerformanceQuery):
+    # Input block HW transfer (only for elements present)
+    ifm_bytes = Shape4D.min(query.ifm_shape, query.config.ifm_block).elements()
+    cycles_ifm_blk = arch.memory_latency[query.ifm_memory_area][BandwidthDirection.Read]
+    cycles_ifm_blk = cycles_ifm_blk + (
+        _estimate_memory_transfer_efficiency(
+            arch,
+            True,
+            query.ifm_memory_area,
+            query.ifm_format,
+            query.ifm_bits,
+            query.config.ifm_block,
+            query.ifm_shape,
+            ifm_bytes,
+        )
+        / arch.memory_bandwidths_per_cycle[query.ifm_memory_area]
+    )
+    # Output block HW transfer (only for elements present)
+    ofm_bytes = Shape4D.min(query.ofm_shape, query.config.ofm_block).elements()
+    cycles_ofm_blk = arch.memory_latency[query.ofm_memory_area][BandwidthDirection.Write]
+    cycles_ofm_blk = cycles_ofm_blk + (
+        _estimate_memory_transfer_efficiency(
+            arch,
+            False,
+            query.ofm_memory_area,
+            query.ofm_format,
+            query.ofm_bits,
+            query.config.ofm_block,
+            query.ofm_shape,
+            ofm_bytes,
+        )
+        / arch.memory_bandwidths_per_cycle[query.ofm_memory_area]
+    )
+    return cycles_ifm_blk, cycles_ofm_blk
+
+
+def _estimate_output_cycles_per_element(arch, op_type: Op, faf_type: Op, query: PerformanceQuery):
+    if query.npu_block_type == NpuBlockType.ElementWise and query.ifm_bits == 32:
+        # Unary op else Binary op
+        output_perf_index = 0 if query.ifm2_shape is None else 1
+    elif op_type == Op.Mul and query.ofm_bits == 32:
         output_perf_index = 2
-    elif primary_op.type == Op.Mul or (
-        npu_block_type
+    elif op_type == Op.Mul or (
+        query.npu_block_type
         in (
             NpuBlockType.ConvolutionMxN,
             NpuBlockType.ConvolutionDepthWise,
@@ -174,31 +248,24 @@
             NpuBlockType.ReduceSum,
             NpuBlockType.VectorProduct,
         )
-        and use_acc_40bits
+        and query.config.acc_type == SHRAMElements.Acc40
     ):
         output_perf_index = 3
-    elif primary_op.type in (Op.Add, Op.Sub):
-        input_scale = ifm_tensor.quantization.scale_f32
-        input2_scale = ifm2_tensor.quantization.scale_f32
-        output_scale = ofm_tensor.quantization.scale_f32
-
-        if "resizebilinear" in primary_op.attrs:
-            output_scale = input2_scale
-
-        if None in (input_scale, input2_scale, output_scale) or input_scale == input2_scale:
+    elif op_type in (Op.Add, Op.Sub):
+        if False:
             # Simple Add/Sub
             output_perf_index = 4
         else:
-            # Advanced Add/Sub
+            # Advanced Add/Sub. TODO: add as a performance selection per operator variant
             output_perf_index = 5
-    elif primary_op.type.is_maxpool_op():
+    elif op_type.is_maxpool_op():
         output_perf_index = 6
     else:
         output_perf_index = 7
 
-    if faf in (Op.Sigmoid, Op.Tanh, Op.LUT):
+    if faf_type in (Op.Sigmoid, Op.Tanh, Op.LUT):
         activation_perf_index = 0
-    elif faf in (Op.Relu, Op.Relu6, Op.ReluN1To1):
+    elif faf_type in (Op.Relu, Op.Relu6, Op.ReluN1To1):
         activation_perf_index = 1
     else:
         activation_perf_index = 2
@@ -207,69 +274,48 @@
         arch.output_cycles_per_elem[output_perf_index], arch.activation_cycles_per_elem[activation_perf_index]
     )
 
-    if primary_op.type.is_elementwise_op() and block_config is not None:
-        num_elems_blk = block_config.width * block_config.height * block_config.depth
-        cycle_cmd = get_minimal_cmd_cycles(
-            arch,
-            ifm_tensor,
-            ofm_tensor,
-            block_config,
-            block_config,
-            num_elems_blk * cycle_per_elem,
-            primary_op.ifm_shapes[0],
-            primary_op.ofm_shapes[0],
-        )
+    if op_type.is_elementwise_op():
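+        # Bound the per-element cost from below by the block-level command/memory
+        # overhead amortised over the elements of one OFM block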
+        num_elems_blk = query.config.ofm_block.elements()
+        ifm_blk_cycles, ofm_blk_cycles = _estimate_minimum_memory_cycles(arch, query)
+        cycle_cmd = ifm_blk_cycles + ofm_blk_cycles
+        cycle_cmd = (cycle_cmd + cycle_per_elem * num_elems_blk) / 4  # per DPU
         cycle_per_elem = max(cycle_per_elem, cycle_cmd / num_elems_blk)
 
-    return num_elems * cycle_per_elem
+    return cycle_per_elem
 
 
-def estimate_conv_pooling_cycles(
-    arch,
-    npu_block_type,
-    primary_op,
-    ifm_block: Block,
-    ofm_block: Block,
-    block_traversal,
-    kernel_dims,
-    ifm_tensor,
-    ofm_tensor,
-    scale_tensor=None,
-):
-    ofm_ublock = Block(arch.config.ofm_ublock.width, arch.config.ofm_ublock.height, arch.config.ofm_ublock.depth)
-    ifm_tens_shape = primary_op.ifm_shapes[0]
-    ofm_tens_shape = primary_op.ofm_shapes[0]
+def _estimate_conv_cycles(arch, op_type: Op, faf_type: Op, query: PerformanceQuery):
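+    # Clamp the configured block sizes to the tensor shapes so the estimate uses
+    # the actual extents when a tensor is smaller than its block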
+    ifm_block = Shape4D.min(query.ifm_shape, query.config.ifm_block)
+    ofm_block = Shape4D.min(query.ofm_shape, query.config.ofm_block)
 
     if (
         arch.config.ofm_ublock.height == 2
-        and npu_block_type
+        and query.npu_block_type
         in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.VectorProduct)
-        and ofm_tens_shape.height == 1
+        and query.ofm_shape.height == 1
         # Optimisation only applies for even width tensors
-        and ofm_tens_shape.width % 2 == 0
-        and kernel_dims[0] == 1
+        and query.ofm_shape.width % 2 == 0
+        and query.kernel.height == 1
     ):
-        ofm_ublock.width = 4
-        ofm_ublock.height = 1
-        ofm_block.height = 1
+        ofm_ublock = Shape4D(1, 1, 4, arch.config.ofm_ublock.depth)
+        ofm_block = ofm_block.with_height(1)
+    else:
+        ofm_ublock = Shape4D(arch.config.ofm_ublock.to_hwc())
 
     num_ublk_x = numeric_util.round_up_divide(ofm_block.width, ofm_ublock.width)
-    num_ublk_y = ofm_block.height // ofm_ublock.height
+    num_ublk_y = numeric_util.round_up_divide(ofm_block.height, ofm_ublock.height)
     num_ublk_xy = num_ublk_x * num_ublk_y
-    num_ublk_z = ofm_block.depth // ofm_ublock.depth
-    num_ofm_blk = 0
-    total_cycles = 0
-    num_elems_blk = ofm_block.width * ofm_block.height * ofm_block.depth
-    use_acc_40bits = is_acc_40bits_used(npu_block_type, ifm_tensor, ofm_tensor)
+    num_ublk_z = numeric_util.round_up_divide(ofm_block.depth, ofm_ublock.depth)
+    use_acc_40bits = query.config.acc_type == SHRAMElements.Acc40
 
-    sub_kernel_limits = arch.sub_kernel_limits[npu_block_type]
-    n_sub_kernels_y = numeric_util.round_up_divide(kernel_dims[0], sub_kernel_limits[0])
-    n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
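+    # Split the kernel into sub-kernels; the IFM block is re-fetched for each sub-kernel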
+    sub_kernel_limits = arch.sub_kernel_limits[query.npu_block_type]
+    n_sub_kernels_y = numeric_util.round_up_divide(query.kernel.height, sub_kernel_limits[0])
+    n_sub_kernels_x = numeric_util.round_up_divide(query.kernel.width, sub_kernel_limits[1])
     sub_kernel_x = [
-        min((kernel_dims[1] - i * sub_kernel_limits[1]), sub_kernel_limits[1]) for i in range(n_sub_kernels_x)
+        min((query.kernel.width - i * sub_kernel_limits[1]), sub_kernel_limits[1]) for i in range(n_sub_kernels_x)
     ]
     sub_kernel_y = [
-        min((kernel_dims[0] - i * sub_kernel_limits[0]), sub_kernel_limits[0]) for i in range(n_sub_kernels_y)
+        min((query.kernel.height - i * sub_kernel_limits[0]), sub_kernel_limits[0]) for i in range(n_sub_kernels_y)
     ]
     sub_kernel_size = (x * y for y in sub_kernel_y for x in sub_kernel_x)
 
@@ -277,27 +323,27 @@
     cycles_wb = 32 * ofm_ublock.depth // 8
 
     for num_kernel_elems in sub_kernel_size:
-        if npu_block_type == NpuBlockType.Pooling:
+        if query.npu_block_type == NpuBlockType.Pooling:
             num_kernel_steps = 1
             cycles = max(4, num_kernel_elems) * num_ublk_xy * num_ublk_z
-            if ifm_tensor.dtype.size_in_bits() == 16 and arch.accelerator_config != Accelerator.Ethos_U55_32:
+            if query.ifm_bits == 16 and arch.accelerator_config != Accelerator.Ethos_U55_32:
                 cycles *= 2
-        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
+        elif query.npu_block_type == NpuBlockType.ConvolutionDepthWise:
             cycles = 4 * num_ublk_xy
-            if ifm_tensor.dtype.size_in_bits() == 16:
+            if query.ifm_bits == 16:
                 cycles *= 2
             num_kernel_steps = numeric_util.round_up_divide(num_kernel_elems, 4)
             cycles = max(cycles_wb, cycles) * num_kernel_steps * num_ublk_z
         elif (
-            (npu_block_type == NpuBlockType.ConvolutionMxN and block_traversal != TensorBlockTraversal.PartKernelFirst)
-            or npu_block_type == NpuBlockType.VectorProduct
-            or npu_block_type == NpuBlockType.ReduceSum
+            (query.npu_block_type == NpuBlockType.ConvolutionMxN and not query.config.is_partkernel)
+            or query.npu_block_type == NpuBlockType.VectorProduct
+            or query.npu_block_type == NpuBlockType.ReduceSum
         ):
             num_kernel_steps = num_kernel_elems
             cycles = max(cycles_wb, 4 * num_ublk_xy) * num_kernel_steps * num_ublk_z
         else:
-            assert block_traversal == TensorBlockTraversal.PartKernelFirst
-            divider = 2 if ifm_tensor.dtype.size_in_bits() == 16 else 4
+            assert query.config.is_partkernel
+            divider = 2 if query.ifm_bits == 16 else 4
             num_kernel_steps = numeric_util.round_up_divide(num_kernel_elems, divider)
             cycles = max(cycles_wb, 4 * num_ublk_xy) * (
                 num_kernel_steps * numeric_util.round_up_divide(ifm_block.depth, 8) * num_ublk_z
@@ -314,345 +360,199 @@
             if (num_ublk_x == 1 or num_ublk_y == 1) and num_ublk_z > 1 and use_acc_40bits:
                 delay_cycles += delay * num_ublk_z
         else:
-            delay = (
-                3
-                if use_acc_40bits and arch.accelerator_config in (Accelerator.Ethos_U55_64, Accelerator.Ethos_U55_128)
-                else 2
-            )
+            if use_acc_40bits and arch.accelerator_config in (Accelerator.Ethos_U55_64, Accelerator.Ethos_U55_128):
+                delay = 3
+            else:
+                delay = 2
+
             if num_ublk_x == 1 and num_ublk_y == 1:
                 if num_ublk_z == 1:
                     delay_cycles = delay * num_kernel_steps
                 elif num_kernel_steps > 1:
                     delay_cycles = delay * (num_kernel_steps - 1) * num_ublk_z
 
-        if npu_block_type == NpuBlockType.ConvolutionMxN and block_traversal == TensorBlockTraversal.PartKernelFirst:
+        if query.npu_block_type == NpuBlockType.ConvolutionMxN and query.config.is_partkernel:
             delay_cycles *= numeric_util.round_up_divide(ifm_block.depth, 8)
 
         cycles_dpu_blk += cycles
         cycles_dpu_blk += delay_cycles
 
-    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
-        cycles_dpu_blk *= numeric_util.round_up_divide(ifm_tens_shape.depth, ifm_block.depth)
+    if query.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
+        cycles_dpu_blk *= numeric_util.round_up_divide(query.ifm_shape.depth, ifm_block.depth)
 
     cycles_dpu_blk /= arch.ncores
 
-    num_ofm_blk = (
-        numeric_util.round_up_divide(ofm_tens_shape.height, ofm_block.height)
-        * numeric_util.round_up_divide(ofm_tens_shape.width, ofm_block.width)
-        * numeric_util.round_up_divide(ofm_tens_shape.depth, ofm_block.depth)
-    )
+    # Estimate output cycles
+    num_ofm_blks = query.ofm_shape.div_round_up(ofm_block).elements()
+    cycles_output_blk = _estimate_output_cycles_per_element(arch, op_type, faf_type, query) * ofm_block.elements()
 
-    cycles_output_blk = estimate_output_cycles(
-        arch, npu_block_type, primary_op, num_elems_blk, ifm_tensor, ofm_tensor, use_acc_40bits
-    )
-
-    if scale_tensor:
+    # Cycles for reading the scale and bias tensor
+    if query.const_shape.depth > 0:
         cycles_bias_blk = (
-            10
-            * min(ofm_block.depth, ofm_tens_shape.depth)
-            * arch.memory_latency[scale_tensor.mem_area][BandwidthDirection.Read]
-            / 256
+            10 * ofm_block.depth * arch.memory_latency[query.const_memory_area][BandwidthDirection.Read] / 256
         )
         cycles_output_blk = max(cycles_output_blk, cycles_bias_blk)
 
-    cycles_cmd = get_minimal_cmd_cycles(
-        arch,
-        ifm_tensor,
-        ofm_tensor,
-        ifm_block,
-        ofm_block,
-        cycles_dpu_blk,
-        ifm_tens_shape,
-        ofm_tens_shape,
-        cycles_output_blk,
-    )
+    ifm_blk_cycles, ofm_blk_cycles = _estimate_minimum_memory_cycles(arch, query)
+    cycles_cmd = ifm_blk_cycles + ofm_blk_cycles
+    cycles_cmd = (cycles_cmd + cycles_output_blk + cycles_dpu_blk) / 4  # per DPU
+
     cycles_dpu_blk = max(cycles_dpu_blk, cycles_cmd)
     cycles_output_blk = max(cycles_output_blk, cycles_cmd)
 
     if cycles_dpu_blk > cycles_output_blk:
-        total_cycles = cycles_dpu_blk * num_ofm_blk + cycles_output_blk
+        total_cycles = cycles_dpu_blk * num_ofm_blks + cycles_output_blk
     else:
-        total_cycles = cycles_output_blk * num_ofm_blk + cycles_dpu_blk
+        total_cycles = cycles_output_blk * num_ofm_blks + cycles_dpu_blk
 
     return total_cycles
 
 
-def estimate_memory_transfer_efficiency(
-    arch, mem_area, direction, tensor, block_size: Block, replace_bw=None, shape4D=None
-):
-    if tensor.format not in (TensorFormat.NHWC, TensorFormat.NHCWB16):
-        return tensor.bandwidth() if replace_bw is None else replace_bw
+def measure_mem2mem_cycles(arch, from_mem_area, to_mem_area, to_transfer):
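+    # A memory-to-memory transfer is limited by the slower of the two memories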
+    from_cycles = to_transfer // arch.memory_bandwidths_per_cycle[from_mem_area]
+    to_cycles = to_transfer // arch.memory_bandwidths_per_cycle[to_mem_area]
+    return max(from_cycles, to_cycles)
 
-    # Estimate memory transfer efficiency by calculating the burst length
-    # this is related to data format, block shape, and tensor shape, etc.
-    burst_len = 0
-    elem_size = tensor.dtype.size_in_bytes()
-    is_ifm = direction == BandwidthDirection.Read
-    tens = tensor.clone()
 
-    if not tensor.needs_linear_format:
-        tens.set_format(TensorFormat.NHCWB16, arch)
-    strides = tens.get_strides(shape4D=shape4D)
+def measure_cycle_cost(arch, op_type: Op, faf_type: Op, query: PerformanceQuery):
+    cycles = CycleCost()
 
-    if tens.format == TensorFormat.NHCWB16:
-        if strides[1] == block_size.depth:
-            burst_len = elem_size * block_size.depth * block_size.width
-        elif is_ifm:
-            burst_len = 16 * elem_size * block_size.width
+    # Convolution/Vector product cycle calculation
+    if query.npu_block_type in (
+        NpuBlockType.ConvolutionMxN,
+        NpuBlockType.ConvolutionDepthWise,
+        NpuBlockType.VectorProduct,
+        NpuBlockType.Pooling,
+        NpuBlockType.ReduceSum,
+    ):
+        # cycles.op_macs and cycles.op_cycles must both be able to hold values wider than 32 bits
+        if query.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling):
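+            # Depthwise and pooling use a single input channel per output channel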
+            cycles.op_macs = int(query.kernel.elements_wh()) * 1 * int(query.ofm_shape.elements())
         else:
-            burst_len = 16 * elem_size * block_size.width * arch.ncores
+            cycles.op_macs = (
+                int(query.kernel.elements_wh()) * int(query.ifm_shape.depth) * int(query.ofm_shape.elements())
+            )
+
+        cycles.op_cycles = int(_estimate_conv_cycles(arch, op_type, faf_type, query))
+    # Elementwise cycle calculation
+    elif query.npu_block_type == NpuBlockType.ElementWise:
+        cycles.op_macs = 0
+        cycles.op_cycles = int(_estimate_output_cycles_per_element(arch, op_type, faf_type, query)) * int(
+            query.ofm_shape.elements()
+        )
     else:
-        assert tens.format == TensorFormat.NHWC
-        if is_ifm:
-            if strides[3] == block_size.depth:
-                burst_len = elem_size * block_size.depth * block_size.width
-            else:
-                burst_len = elem_size * block_size.depth
-        else:
-            if block_size.depth <= 16 and strides[3] == block_size.depth:
-                burst_len = elem_size * block_size.depth * block_size.width
-            else:
-                burst_len = min(64, 16 * elem_size * arch.ncores, block_size.depth * elem_size)
+        assert False
 
-    burst_len = min(arch.memory_burst_length[mem_area], burst_len)
-    bw = tens.bandwidth() if replace_bw is None else replace_bw
-
-    return bw * (arch.memory_burst_length[mem_area] / burst_len)
+    return cycles
 
 
-def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None, force_outputs_to_fast_storage=False):
-    if block_config is None:
-        block_config = ps.block_config
-    bws = make_bandwidth_array()
-    scaled_bws = make_bandwidth_array()  # scaled bw with memory transfer efficiency
-    macs = 0
-    cycles = make_cycles_array()
-    ifm_read_multiple = 1
-    weight_read_multiple = 0
+def measure_element_access(arch, query: PerformanceQuery):
+    access = ElementAccess()
 
-    if ps.placement in (PassPlacement.MemoryOnly, PassPlacement.StartupInit):
-        return bws, macs, cycles, ifm_read_multiple, weight_read_multiple  # nothing real happening in this pass
+    ifm_block = Shape4D.min(query.ifm_shape, query.config.ifm_block)
+    ofm_block = Shape4D.min(query.ofm_shape, query.config.ofm_block)
+    ifm_rounding = Shape4D(list(arch.storage_rounding_quantums[query.ifm_format]))
 
-    explicit_padding = (0, 0, 0, 0)
-    primary_op = ps.primary_op
-    replacement_read_bws = {}
-    ofm_block = Block(block_config[1], block_config[0], block_config[3])
-    ifm_block = Block(block_config[1], block_config[0], block_config[3])
+    # Number of ofm blocks in the overall output shape
+    ofm_blocks = query.ofm_shape.div_round_up(ofm_block)
+    ofm_block_depth = ofm_block.depth
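+    # Depthwise and pooling cover the full IFM depth in one pass, so do not block over depth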
+    if query.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling):
+        ofm_blocks = ofm_blocks.with_depth(1)
+        ofm_block_depth = query.ifm_shape.depth
 
-    if ps.placement == PassPlacement.Npu and primary_op:
-        explicit_padding = primary_op.attrs.get("explicit_padding", explicit_padding)
-        assert primary_op.type.npu_block_type == ps.npu_block_type
-        npu_block_type = primary_op.type.npu_block_type
+    # Convolution & pooling
+    if query.npu_block_type in (
+        NpuBlockType.ConvolutionMxN,
+        NpuBlockType.ConvolutionDepthWise,
+        NpuBlockType.VectorProduct,
+        NpuBlockType.Pooling,
+        NpuBlockType.ReduceSum,
+    ):
+        # Number of sub kernels
+        sub_kernel_limits = arch.sub_kernel_limits[query.npu_block_type]
+        subkernels = numeric_util.round_up_divide(query.kernel.width, sub_kernel_limits[0])
+        subkernels *= numeric_util.round_up_divide(query.kernel.height, sub_kernel_limits[1])
 
-        ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
-        ifm_tensor_shape = ps.primary_op.ifm_shapes[0]
-        ofm_tensor_shape = ps.primary_op.ofm_shapes[0]
-        ofm_block.width = min(ofm_block.width, ofm_tensor_shape.width)
-        ofm_block.height = min(ofm_block.height, ofm_tensor_shape.height)
-        ofm_block.depth = min(ofm_block.depth, ofm_tensor_shape.depth)
+        ofm_block_count = ofm_blocks.elements()
 
-        if npu_block_type == NpuBlockType.ReduceSum:
-            block_traversal = TensorBlockTraversal.DepthFirst
-        elif npu_block_type in (
-            NpuBlockType.ConvolutionMxN,
-            NpuBlockType.ConvolutionDepthWise,
-            NpuBlockType.VectorProduct,
-        ):
-            block_traversal = weight_tensor.block_traversal
-        else:
-            block_traversal = TensorBlockTraversal.Default
-        ifm_block_depth = get_ifm_block_depth(
-            npu_block_type, ifm_tensor_shape.depth, ifm_tensor.dtype.size_in_bits(), block_traversal, ofm_block.depth
-        )
-        ifm_block = arch.get_ifm_block_size(
-            ifm_block_depth, ofm_block, primary_op.kernel, ifm_resampling_mode=ifm_tensor.resampling_mode
-        )
-        ifm_block.width = min(ifm_block.width, ifm_tensor_shape.width)
-        ifm_block.height = min(ifm_block.height, ifm_tensor_shape.height)
-
-        if npu_block_type in (
-            NpuBlockType.ConvolutionMxN,
-            NpuBlockType.ConvolutionDepthWise,
-            NpuBlockType.VectorProduct,
-            NpuBlockType.Pooling,
-            NpuBlockType.ReduceSum,
-        ):
-            # extent the ifm to full dimension
-
-            batch_size = ifm_tensor_shape.batch
-
-            # add in padding, height += top and bottom, width  += left and right
-            ifm_tensor_shape = ifm_tensor_shape.add(
-                0, explicit_padding[0] + explicit_padding[2], explicit_padding[1] + explicit_padding[3], 0
-            )
-
-            if npu_block_type != NpuBlockType.Pooling:
-                if npu_block_type == NpuBlockType.ReduceSum:
-                    weight_tensor_shape = [1, 1, ifm_tensor.shape[3], ofm_tensor.shape[3]]
-                    weight_tensor_bandwidth_shape = [0] * 4
-                    weight_tensor_element_size = 0
-                    weight_tensor_bandwidth_compression_scale = 0.0
-                else:
-                    # For Vector product, weight format of IO is extended to HWIO, with H=W=1
-                    weight_tensor_shape = numeric_util.full_shape(4, weight_tensor.shape, 1)
-                    weight_tensor_bandwidth_shape = numeric_util.full_shape(4, weight_tensor.bandwidth_shape, 1)
-                    weight_tensor_element_size = weight_tensor.element_size()
-                    weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
-
-                nn_ops = (
-                    int(ofm_tensor_shape.batch)
-                    * int(ofm_tensor_shape.height)
-                    * int(ofm_tensor_shape.width)
-                    * int(weight_tensor_shape[0])
-                    * int(weight_tensor_shape[1])
-                    * int(weight_tensor_shape[2])
-                    * int(weight_tensor_shape[3])
-                )
-            else:
-                weight_tensor_shape = [
-                    *primary_op.get_kernel_size(),
-                    1,
-                    ifm_tensor_shape.depth,
-                ]
-                weight_tensor_bandwidth_shape = weight_tensor_shape
-                weight_tensor_element_size = 0
-                weight_tensor_bandwidth_compression_scale = 0.0
-                nn_ops = 0  # pooling doesn't count as NN ops
-
-            kernel_dims = weight_tensor_shape[:2]
-
-            sub_kernel_limits = arch.sub_kernel_limits[npu_block_type]
-            # count the sub kernels; the IFM block needs to be refetched for each of them
-            n_sub_kernels_y = numeric_util.round_up_divide(kernel_dims[0], sub_kernel_limits[0])
-            n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
-            n_sub_kernels = n_sub_kernels_y * n_sub_kernels_x
-
-            n_full_depth_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], ofm_block.depth)
-            if npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling):
-                n_full_depth_stages = 1  # force to no reread
-
-            ifm_read_multiple = n_sub_kernels * n_full_depth_stages
-            replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth() * ifm_read_multiple
-
-            weight_read_multiple = numeric_util.round_up_divide(
-                ofm_tensor_shape.height, ofm_block.height
-            ) * numeric_util.round_up_divide(ofm_tensor_shape.width, ofm_block.width)
-            replacement_read_bws[weight_tensor] = (
-                batch_size
-                * shape_num_elements(weight_tensor_bandwidth_shape)
-                * weight_tensor_element_size
-                * weight_tensor_bandwidth_compression_scale
-                * weight_read_multiple
-            )
-
-            macs += nn_ops
-            cycles[PassCycles.Npu] = estimate_conv_pooling_cycles(
-                arch,
-                npu_block_type,
-                primary_op,
-                ifm_block,
-                ofm_block,
-                block_traversal,
-                kernel_dims,
-                ifm_tensor,
-                ofm_tensor,
-                ps.scale_tensor,
-            )
-        elif npu_block_type == NpuBlockType.ElementWise:
-            # Work out how many elements we have and calculate performance.
-            cycles[PassCycles.Npu] = estimate_output_cycles(
-                arch,
-                npu_block_type,
-                primary_op,
-                ofm_tensor.elements(),
-                ps.ifm_tensor,
-                ps.ofm_tensor,
-                None,
-                ps.ifm2_tensor,
-                ofm_block,
-            )
-
-        prev_npu_pass = next((npu_ps for npu_ps in ps.dag_predecessors if npu_ps.placement is PassPlacement.Npu), None)
-        if prev_npu_pass is None:
-            # cycles for DMA ops in first pass
-            dma_ops = (op for op in ps.ops if op.type == Op.DMA)
-            for dma_op in dma_ops:
-                mem_area = dma_op.attrs["source"]
-                for tens in dma_op.inputs:
-                    cycles[PassCycles.Npu] += tens.storage_size() / arch.memory_bandwidths_per_cycle[mem_area]
-
-    if rewrite_list is not None:
-        # apply the desired rewrites
-        for rewrite_op, tens, _, _, _, ps_to_rewrite in rewrite_list:
-            if ps != ps_to_rewrite:
-                continue
-            if rewrite_op == SchedulerRewrite.Nop:
-                pass  # these are fine, no bandwidth changes
-            elif rewrite_op in (SchedulerRewrite.ChangeTensorSubPurpose,):
-                bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += replacement_read_bws[tens]
-                if tens.purpose == TensorPurpose.FeatureMap:
-                    scaled_bw = estimate_memory_transfer_efficiency(
-                        arch,
-                        arch.fast_storage_mem_area,
-                        BandwidthDirection.Read,
-                        tens,
-                        ifm_block,
-                        replacement_read_bws[tens],
-                    )
-                else:
-                    scaled_bw = replacement_read_bws[tens]
-                scaled_bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += scaled_bw
-                replacement_read_bws[tens] = 0
-
-    for tens in ps.outputs:
-        if force_outputs_to_fast_storage:
-            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
-            scaled_bws[arch.fast_storage_mem_area][tens.purpose][
-                BandwidthDirection.Write
-            ] += estimate_memory_transfer_efficiency(
-                arch, arch.fast_storage_mem_area, BandwidthDirection.Write, tens, ofm_block, shape4D=ps.ofm_shapes[0],
-            )
-        else:
-            bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
-            scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_transfer_efficiency(
-                arch, tens.mem_area, BandwidthDirection.Write, tens, ofm_block, shape4D=ps.ofm_shapes[0]
-            )
-
-    for tens in ps.intermediates:
-        bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
-        scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
-
-        if tens in replacement_read_bws:
-            bw = replacement_read_bws[tens]
-        else:
-            bw = tens.bandwidth()
-
-        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw
-        scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw
-
-    for tens in ps.inputs:
-        if tens in replacement_read_bws:
-            bw = replacement_read_bws[tens]
-        else:
-            bw = tens.bandwidth()
-
-        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw
-
-        op_shape = None
-        if ps.placement == PassPlacement.Npu and primary_op:
-            if tens == ps.ifm_tensor:
-                op_shape = ps.ifm_shapes[0]
-            elif tens == ps.ifm2_tensor:
-                op_shape = ps.ifm_shapes[1]
-
-        scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += estimate_memory_transfer_efficiency(
-            arch, tens.mem_area, BandwidthDirection.Read, tens, ifm_block, bw, op_shape
+        ifm_fetch = (
+            Shape4D.round_up(ifm_block, ifm_rounding).elements_wh()
+            * Shape4D.round_up(query.ifm_shape, ifm_rounding).depth
         )
 
-    # quick build access counts for only current pass, even though these aren't the final numbers
-    update_summary_cycles(arch, scaled_bws, cycles)
+        if query.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling):
+            kernel_read = query.kernel.elements_wh() * 1  # depthwise/pooling: no re-read over the IFM depth
+        else:
+            kernel_read = query.kernel.elements_wh() * query.ifm_shape.depth
 
-    return bws, macs, cycles, ifm_read_multiple, weight_read_multiple
+        weight_fetch = kernel_read * ofm_block_depth * ofm_block_count
+
+        access.ifm_read[0] = ifm_fetch * subkernels * ofm_block_count
+
+        if query.npu_block_type not in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
+            access.const_read[0] = weight_fetch
+            access.const_read[1] = query.ofm_shape.depth  # Scales & biases
+            access.weights_refetch = ofm_blocks.elements_wh()
+    # Elementwise
+    elif query.npu_block_type == NpuBlockType.ElementWise:
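+        # Reads follow the (rounded) OFM shape; scalar inputs are only counted when
+        # they are wider than 8 bits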
+        if query.ifm_shape.elements() == 1:
+            if query.ifm_bits > 8:
+                # ifm is a non-8-bit scalar
+                access.ifm_read[0] = Shape4D.round_up(query.ifm_shape, ifm_rounding).elements()
+            if query.ifm2_shape:
+                access.ifm_read[1] = Shape4D.round_up(query.ofm_shape, ifm_rounding).elements()
+        else:
+            access.ifm_read[0] = Shape4D.round_up(query.ofm_shape, ifm_rounding).elements()
+            if query.ifm2_shape:
+                if query.ifm2_shape.elements() > 1:
+                    access.ifm_read[1] = Shape4D.round_up(query.ofm_shape, ifm_rounding).elements()
+                elif query.ifm2_bits > 8:
+                    # ifm2 is a non-8-bit scalar
+                    access.ifm_read[1] = Shape4D.round_up(query.ifm2_shape, ifm_rounding).elements()
+    # Unknown
+    else:
+        assert False
+
+    ofm_rounding = Shape4D(list(arch.storage_rounding_quantums[query.ofm_format]))
+    access.ofm_write = Shape4D.round_up(query.ofm_shape, ofm_rounding).elements()
+    return access
+
+
+def measure_performance_cost(
+    arch, op_type: Op, faf_type: Op, query: PerformanceQuery, offset: Shape4D, sub_shape: Shape4D
+):
+    assert (query.ofm_bits > 0) and (query.ifm_bits > 0)
+    assert query.ofm_shape.elements() != 0
+
+    # Default to the start of the OFM if no offset is provided
+    if offset is None:
+        offset = Shape4D(0, 0, 0, 0)
+
+    # Default to entire area if no sub-shape provided
+    if sub_shape is None:
+        sub_shape = query.ofm_shape
+    else:
+        sub_shape = Shape4D.min(sub_shape, query.ofm_shape)
+
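+    # Cost a copy of the query with the OFM clipped to the requested sub-area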
+    sub_query = copy.deepcopy(query)
+    sub_query.ofm_shape = query.ofm_shape.clip(offset, sub_shape)
+
+    access = ElementAccess()
+    cycles = CycleCost()
+
+    cycle_tmp = measure_cycle_cost(arch, op_type, faf_type, sub_query)
+    cycles += cycle_tmp
+    access = measure_element_access(arch, sub_query)
+
+    return access, cycles
+
+
+def make_bandwidth_array():
+    return np.zeros((MemArea.Size, TensorPurpose.Size, BandwidthDirection.Size))
+
+
+def make_cycles_array():
+    return np.zeros(PassCycles.Size)
 
 
 def update_summary_cycles(arch, bws, cycles):
@@ -669,42 +569,169 @@
     return cycles
 
 
-def collate_stats_for_cascaded_pass(arch, bws, macs, cycles):
-    return bws, macs, cycles
+def estimate_full_op_performance(
+    arch, schedule: Schedule, op: SchedulerOperation, prev_op: SchedulerOperation, block_config
+):
+    cycles_a = make_cycles_array()
+    bws = make_bandwidth_array()
+    scaled_bws = make_bandwidth_array()  # scaled bw with memory transfer efficiency
+    macs = 0
+
+    query = PerformanceQuery(op.op_type.npu_block_type)
+    query.ifm_shape = op.ifm.shape
+    query.ifm_format = op.ifm.format
+    query.ifm_memory_area = op.ifm.mem_area
+    query.ifm_bits = op.ifm.dtype.size_in_bits()
+    query.ifm2_shape = op.ifm2 and op.ifm2.shape
+    query.ifm2_format = op.ifm2 and op.ifm2.format
+    query.ifm2_memory_area = op.ifm2 and op.ifm2.mem_area
+    query.ifm2_bits = op.ifm2 and op.ifm2.dtype.size_in_bits()
+    query.ofm_shape = op.ofm.shape
+    query.ofm_memory_area = op.ofm.mem_area
+    query.ofm_bits = op.ofm.dtype.size_in_bits()
+    query.ofm_format = op.ofm.format
+    query.kernel = op.kernel
+    query.config = block_config
+
+    cost = schedule.cost_map[op]
+    prev_cost = schedule.cost_map[prev_op] if prev_op else None
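+    # Scales and biases are read from the double-buffered weight tensor when one
+    # exists, otherwise from the encoded weight stream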
+    if op.parent_op.bias:
+        query.const_shape = Shape4D(1, 1, 1, op.ofm.shape.depth)
+        if cost.buffered_weight_tensor:
+            query.const_memory_area = cost.buffered_weight_tensor.mem_area
+        else:
+            query.const_memory_area = cost.npu_weights_tensor.mem_area
+
+    cycles = measure_cycle_cost(arch, op.op_type, op.parent_op.activation and op.parent_op.activation.op_type, query)
+    cycles_a[PassCycles.Npu] = cycles.op_cycles
+    macs = cycles.op_macs
+
+    access = measure_element_access(arch, query)
+
+    # How many NPU cycles are available under the previously executing
+    # operator for performing buffered DMA transfers
+    slack_cycles = prev_cost.slack_buffering_cycles if prev_cost else 0
+
+    # LUT Transfer
+    parent_op = op.parent_op
+    lut_transfer_cycles = 0
+    if parent_op.activation_lut:
+        lut_tensor = [tens for tens in parent_op.inputs if tens.purpose == TensorPurpose.LUT][0]
+        src_tensor = lut_tensor.src_tensor
+        if src_tensor and lut_tensor.mem_area != src_tensor.mem_area:
+            bw = src_tensor.storage_size()
+            lut_transfer_cycles = measure_mem2mem_cycles(arch, src_tensor.mem_area, lut_tensor.mem_area, bw)
+
+            bws[src_tensor.mem_area][lut_tensor.purpose][BandwidthDirection.Read] += bw
+            # LUT read from SHRAM - TODO: consider removing this
+            scaled_bws[lut_tensor.mem_area][lut_tensor.purpose][
+                BandwidthDirection.Read
+            ] += _estimate_memory_transfer_efficiency(
+                arch,
+                True,
+                lut_tensor.mem_area,
+                lut_tensor.format,
+                lut_tensor.element_size(),
+                query.config.ifm_block,
+                Shape4D(lut_tensor.shape),
+                bw,
+            )
+
+    if cost.npu_weights_tensor and cost.buffered_weight_tensor:
+        # DMA Weight Transfer
+        sz = 0
+        # Get the size of the first DMA
+        for core in range(0, arch.ncores):
+            key = WeightKey(core, 0)
+            if key in cost.npu_weights_tensor.encoded_ranges:
+                weight_range = cost.npu_weights_tensor.encoded_ranges[key]
+                sz += round_up(weight_range.total_bytes, 16)
+
+        total_sz = len(cost.npu_weights_tensor.buffer)
+        bws[cost.npu_weights_tensor.mem_area][TensorPurpose.Weights][BandwidthDirection.Read] += total_sz
+        bws[cost.buffered_weight_tensor.mem_area][TensorPurpose.Weights][BandwidthDirection.Write] += total_sz
+
+        ws_first_transfer_cycles = measure_mem2mem_cycles(
+            arch, cost.npu_weights_tensor.mem_area, cost.buffered_weight_tensor.mem_area, sz
+        )
+
+        # Add cycles for Weight + Scale Transfer
+        cycles_a[PassCycles.Npu] = max(
+            cost.full_weight_transfer_cycles - slack_cycles + cost.slack_buffering_cycles,
+            cycles.op_cycles + max(ws_first_transfer_cycles - slack_cycles, 0),
+        )
+
+        # Add cycles for LUT Transfer
+        cycles_a[PassCycles.Npu] += lut_transfer_cycles
+    else:
+        # Add cycles for LUT Transfer
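+        # (only the part that cannot be hidden by the previous operator's slack cycles)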
+        cycles_a[PassCycles.Npu] += max(lut_transfer_cycles - slack_cycles, 0)
+
+    # OFM write
+    ofm = op.parent_op.ofm
+    bw = access.ofm_write * ofm.element_size()
+    bws[query.ofm_memory_area][ofm.purpose][BandwidthDirection.Write] += bw
+    scaled_bws[ofm.mem_area][ofm.purpose][BandwidthDirection.Write] += _estimate_memory_transfer_efficiency(
+        arch, False, query.ofm_memory_area, ofm.format, query.ofm_bits, query.config.ofm_block, query.ofm_shape, bw
+    )
+
+    # IFM read
+    ifm = op.parent_op.ifm
+    bw = access.ifm_read[0] * ifm.element_size()
+    bws[ifm.mem_area][ifm.purpose][BandwidthDirection.Read] += bw
+    scaled_bws[ifm.mem_area][ifm.purpose][BandwidthDirection.Read] += _estimate_memory_transfer_efficiency(
+        arch, True, query.ifm_memory_area, ifm.format, query.ifm_bits, query.config.ifm_block, query.ifm_shape, bw
+    )
+    if query.ifm2_shape:
+        ifm2 = op.parent_op.ifm2
+        bw = access.ifm_read[1] * ifm2.element_size()
+        bws[ifm2.mem_area][ifm2.purpose][BandwidthDirection.Read] += bw
+        scaled_bws[ifm2.mem_area][ifm2.purpose][BandwidthDirection.Read] += _estimate_memory_transfer_efficiency(
+            arch,
+            True,
+            query.ifm2_memory_area,
+            ifm2.format,
+            op.ifm2.dtype.size_in_bits(),
+            query.config.ifm_block,
+            query.ifm2_shape,
+            bw,
+        )
+
+    # Weight read
+    if access.const_read[0] > 0:
+        # alignment not accounted for in bandwidth_compression_scale_approx
+        encoded_size_approx = (
+            cost.npu_weights_tensor.elements() - access.const_read[1] * op.parent_op.bias.element_size()
+        )
+        orig_weight_size = parent_op.weights.elements()
+        bandwidth_compression_scale_approx = encoded_size_approx / orig_weight_size
+        bw = access.const_read[0] * bandwidth_compression_scale_approx
+        bws[query.const_memory_area][TensorPurpose.Weights][BandwidthDirection.Read] += bw
+
+    if access.const_read[1] > 0:
+        # Scales & biases
+        bw = access.const_read[1] * op.parent_op.bias.element_size()
+        bws[query.const_memory_area][TensorPurpose.FSBias][BandwidthDirection.Read] += bw
+
+    update_summary_cycles(arch, scaled_bws, cycles_a)
+
+    return bws, macs, cycles_a
 
 
-def performance_for_cascaded_pass(arch, cps):
-    total_bws = make_bandwidth_array()
-    total_macs = 0
-    total_cycles = make_cycles_array()
-
-    for ps in cps.passes:
-        bws, macs, cycles, _, _ = performance_metrics_for_pass(arch, ps)
-        ps.bandwidths = bws
-        ps.macs = macs
-        ps.cycles = cycles
-        total_bws += bws
-        total_macs += macs
-        total_cycles += cycles
-
-    bws, macs, cycles = collate_stats_for_cascaded_pass(arch, total_bws, total_macs, total_cycles)
-    cps.bandwidths = bws
-    cps.macs = macs
-    cps.cycles = cycles
-    return bws, macs, cycles
-
-
-def calc_performance_for_network(nng, arch):
+def calc_new_performance_for_network(nng, arch):
     total_bws = make_bandwidth_array()
     total_macs = 0
     total_cycles = np.zeros(PassCycles.Size)
 
     for sg in nng.subgraphs:
-        for cps in sg.cascaded_passes:
-            bws, macs, cycles = performance_for_cascaded_pass(arch, cps)
+        prev_op = None
+        for sched_op in sg.sched_ops:
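+            # Cost each operation against its chosen block config; the previous
+            # operation determines the slack available for buffered transfers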
+            op_info = sg.schedule.cost_map[sched_op]
+            bws, macs, cycles = estimate_full_op_performance(arch, sg.schedule, sched_op, prev_op, op_info.block_config)
             total_bws += bws
             total_macs += macs
             total_cycles += cycles
+            prev_op = sched_op
 
     nng.bandwidths = total_bws
     nng.macs = total_macs
diff --git a/ethosu/vela/npu_serialisation.py b/ethosu/vela/npu_serialisation.py
index ad4d29c..39a7f21 100644
--- a/ethosu/vela/npu_serialisation.py
+++ b/ethosu/vela/npu_serialisation.py
@@ -42,10 +42,8 @@
 
 def copy_compressed_values_to_memory_tensor(memory_tensor, src_tensor):
     start_addr = src_tensor.address
-    for compressed_values in src_tensor.compressed_values:
-        end_addr = start_addr + len(compressed_values)
-        memory_tensor.values[start_addr:end_addr] = compressed_values
-        start_addr = end_addr
+    end_addr = src_tensor.address + src_tensor.storage_size()
+    memory_tensor.values[start_addr:end_addr] = src_tensor.buffer.copy()
 
 
 def copy_ifm_values_to_memory_tensor(memory_tensor, src_tensor):
@@ -94,31 +92,21 @@
         sg.scratch_fast_tensor = scratch_fast_tens
         sg.scratch_fast_tensor.shape[0] = 0
 
-    for cps in sg.cascaded_passes:
-        for ps in cps.passes:
-            if ps.placement == PassPlacement.Npu:
-                if ps.weight_tensor is not None:
-                    # For DMA ops, ps.weight_tensor is referring to the SRAM weight tensor and therefore the address
-                    # is pointing at the destination address of where the weights should be placed in SRAM.
-                    # This ensures that the Flash weight tensor is used instead and thus gets the correct address.
-                    if ps.weight_tensor.ops[0].type == Op.DMA:
-                        copy_compressed_values_to_memory_tensor(sg.flash_tensor, ps.weight_tensor.ops[0].inputs[0])
-                    else:
-                        copy_compressed_values_to_memory_tensor(sg.flash_tensor, ps.weight_tensor)
+    for sched_op in sg.sched_ops:
+        ifm_tensor, ifm2_tensor, _, _, _ = sched_op.parent_op.get_ifm_ifm2_weights_biases_ofm()
 
-                    if ps.scale_tensor.ops[0].type == Op.DMA:
-                        copy_compressed_values_to_memory_tensor(sg.flash_tensor, ps.scale_tensor.ops[0].inputs[0])
-                    else:
-                        copy_compressed_values_to_memory_tensor(sg.flash_tensor, ps.scale_tensor)
+        op_info = sg.schedule.cost_map[sched_op]
+        if op_info.npu_weights_tensor:
+            copy_compressed_values_to_memory_tensor(sg.flash_tensor, op_info.npu_weights_tensor)
 
-                if ps.lut_tensor is not None:
-                    copy_ifm_values_to_memory_tensor(sg.flash_tensor, ps.lut_tensor)
-                if ps.ifm_tensor is not None and ps.ifm_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
-                    copy_ifm_values_to_memory_tensor(sg.flash_tensor, ps.ifm_tensor)
-                if ps.ifm2_tensor is not None and (
-                    ps.ifm2_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast)
-                ):
-                    copy_ifm_values_to_memory_tensor(sg.flash_tensor, ps.ifm2_tensor)
+        if ifm_tensor and ifm_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
+            copy_ifm_values_to_memory_tensor(sg.flash_tensor, ifm_tensor)
+        if ifm2_tensor and (ifm2_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast)):
+            copy_ifm_values_to_memory_tensor(sg.flash_tensor, ifm2_tensor)
+
+        if sched_op.parent_op.activation_lut:
+            copy_ifm_values_to_memory_tensor(sg.flash_tensor, sched_op.parent_ps.lut_tensor)
+
     sg.command_stream_tensor = make_memory_tensor(
         sg.name + "_command_stream", flash_area, MemType.Permanent_CPU, command_stream_size_bytes, True, arch
     )
diff --git a/ethosu/vela/numeric_util.py b/ethosu/vela/numeric_util.py
index d596209..011765f 100644
--- a/ethosu/vela/numeric_util.py
+++ b/ethosu/vela/numeric_util.py
@@ -24,6 +24,10 @@
     return ((a + b - 1) // b) * b
 
 
+def round_down(a, b):
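+    """Round a down to a multiple of b, e.g. round_down(23, 16) == 16"""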
+    return (a // b) * b
+
+
 def round_up_divide(a, b):
     return (a + b - 1) // b
 
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index a5a58e8..6bd955d 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -53,13 +53,23 @@
     Kernel information for NPU operations
     """
 
-    def __init__(self, w: int, h: int, stride_x: int = 1, stride_y: int = 1, dilation_x: int = 1, dilation_y: int = 1):
+    def __init__(
+        self,
+        w: int,
+        h: int,
+        stride_x: int = 1,
+        stride_y: int = 1,
+        dilation_x: int = 1,
+        dilation_y: int = 1,
+        valid_padding=False,
+    ):
         assert stride_x > 0 and stride_y > 0
         assert dilation_x > 0 and dilation_y > 0
         self.width = w
         self.height = h
         self.stride = PointXY(stride_x, stride_y)
         self.dilation = PointXY(dilation_x, dilation_y)
+        self.valid_padding = valid_padding
 
     def elements_wh(self) -> int:
         return self.width * self.height
@@ -70,6 +80,9 @@
     def area_height(self) -> int:
         return (self.height - 1) * self.dilation.y + 1
 
+    def dilation(self) -> PointXY:
+        return self.dilation
+
     def dilated_wh(self) -> Tuple[int, int]:
         """Returns the dilated kernel width/height"""
         return self.dilation.x * (self.width - 1) + 1, self.dilation.y * (self.height - 1) + 1
@@ -149,7 +162,6 @@
     Cumsum = OperatorInfo()
     Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
     CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
-    DMA = OperatorInfo()
     Delegate = OperatorInfo()
     Densify = OperatorInfo()
     DepthToSpace = OperatorInfo()
@@ -422,6 +434,7 @@
         "ofm_shapes",
         "rescale",
         "read_offsets",
+        "read_shapes",
         "rounding_mode",
         "low_precision_scaling",
         "write_offset",
@@ -455,6 +468,7 @@
         # (which overrides the ofm tensor's scale)
         self.rescale = None
         self.read_offsets: List[Shape4D] = [None, None]  # offset for [ifm, ifm2]
+        self.read_shapes: List[Shape4D] = [None, None]  # read shape for [ifm, ifm2]
         self.rounding_mode: Optional[NpuRoundingMode] = None
         # The Mean operator (implemented as a depthwise convolution) requires scaling
         # to be calculated differently in one case. In that case, this is set to True.
@@ -482,6 +496,7 @@
         res.scheduled_pass = self.scheduled_pass
         res.op_index = None  # not relevant as not part of input network
         res.read_offsets = list(self.read_offsets)
+        res.read_shapes = list(self.read_shapes)
         res.rounding_mode = self.rounding_mode
         res.low_precision_scaling = self.low_precision_scaling
 
@@ -788,3 +803,13 @@
                 self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1)))
             if ofm_tensor is not None:
                 self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))
+
+    def has_scaling(self):
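+        """Returns True if all present ifm/ifm2/ofm tensors have quantization information"""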
+        scaled = True
+        for tensor in [self.ifm, self.ifm2, self.ofm]:
+            if tensor is not None:
+                if tensor.quantization is None:
+                    scaled = False
+                    break
+
+        return scaled
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 2a1903d..518b243 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -32,13 +32,12 @@
     Main = 1
     Post = 2
     Mac = 4
-    Dma = 8
-    ElementWise = 16
-    Npu = 32
-    Cpu = 64
-    StartupInit = 128
-    MemoryOnly = 256
-    PostFusingLimited = 512
+    ElementWise = 8
+    Npu = 16
+    Cpu = 32
+    StartupInit = 64
+    MemoryOnly = 128
+    PostFusingLimited = 256
 
 
 mac_main_ops = set(
@@ -87,7 +86,6 @@
 quantization_ops = set((Op.Dequantize, Op.Max, Op.Min))
 cpu_ops = set((Op.Softmax, Op.LRN, Op.Shape, Op.Pad, Op.AddN)) | quantization_ops
 
-npu_dma_ops = set((Op.DMA,))
 startup_init_ops = set((Op.Const, Op.Placeholder, Op.SubgraphInput))
 memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,))
 
@@ -135,16 +133,6 @@
     ),
     (
         # ops_set
-        npu_dma_ops,
-        # incompatible_pack_flags
-        PassFlags.Cpu | PassFlags.MemoryOnly,
-        # flags_to_set
-        PassFlags.Npu | PassFlags.Dma,
-        # flags_to_clear
-        PassFlags.Empty,
-    ),
-    (
-        # ops_set
         startup_init_ops,
         # incompatible_pack_flags
         PassFlags.Npu | PassFlags.Cpu | PassFlags.MemoryOnly,
@@ -261,12 +249,6 @@
                                 assert ifm_tensor is not None, "IFM missing in {}".format(curr_op)
                                 assert ifm_tensor.purpose == TensorPurpose.FeatureMap
 
-                        if flags_to_set & PassFlags.Dma:
-                            # DMAs are special - Output buffers need to be preserved as intermediates,
-                            # if the pass consumes the results
-                            if tens is not None:
-                                reverse_intermediates.append(tens)
-
                         if operation_set is None:
                             print("Warning:", curr_op.type, "operation is unknown or unsupported, placing on CPU")
 
@@ -292,7 +274,7 @@
 
         is_element_wise = True
         for op in reverse_ops_list:
-            if op.type not in elem_wise_ops and op.type not in npu_dma_ops:
+            if op.type not in elem_wise_ops and op.type:
                 is_element_wise = False
                 break
 
@@ -335,11 +317,6 @@
             for inp in primary_op.inputs:
                 if inp is None:
                     continue
-                if len(inp.ops) == 1 and inp.ops[0].type == Op.DMA and inp.purpose == TensorPurpose.FeatureMap:
-                    src_op = inp.ops[0]
-                    if src_op in input_ops_list:
-                        inp = src_op.inputs[0]
-                        input_ops_list.remove(src_op)
                 add_input_list(inp, input_set, input_refcounts, lut_list, ordered_input_list)
             input_ops_list.remove(primary_op)
 
@@ -349,9 +326,6 @@
                 add_input_list(inp, input_set, input_refcounts, lut_list, ordered_input_list)
 
         name = ops_list[0].name
-        non_dma_ops = [op for op in ops_list if op.type != Op.DMA]
-        if non_dma_ops:
-            name = non_dma_ops[0].name
         ps = Pass(name, placement, is_element_wise, npu_block_type)
         ps.ops = ops_list
         ps.primary_op = primary_op
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 6db9fe3..2043127 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -53,11 +53,11 @@
 from .api import NpuRoundingMode
 from .api import NpuShape3D
 from .api import NpuTileBox
+from .architecture_allocator import ArchitectureBlockConfig
+from .architecture_allocator import try_block_config
 from .architecture_features import Accelerator
 from .architecture_features import ArchitectureFeatures
-from .architecture_features import Block
 from .architecture_features import create_default_arch
-from .architecture_features import SharedBufferArea
 from .architecture_features import SHRAMElements
 from .errors import VelaError
 from .ethos_u55_regs.ethos_u55_regs import acc_format
@@ -80,12 +80,10 @@
 from .register_command_stream_util import get_strides
 from .register_command_stream_util import get_wait_dependency
 from .register_command_stream_util import has_ifm2
+from .register_command_stream_util import shape3d_to_block
 from .register_command_stream_util import to_kernel
 from .register_command_stream_util import UNARY_ELEMWISE_OPS
 from .register_command_stream_util import Watermark
-from .shared_buffer_allocation import find_suitable_block_configs
-from .shared_buffer_allocation import shared_buffer_allocation_for_npu_op
-from .shared_buffer_allocation import SharedBufferAllocation
 
 
 class RegisterMachine:
@@ -521,56 +519,40 @@
 
 
 def generate_block_config(
-    emit: CommandStreamEmitter,
-    npu_op: NpuBlockOperation,
-    arch: ArchitectureFeatures,
-    shared_buffer: SharedBufferAllocation,
+    emit: CommandStreamEmitter, block_config: NpuShape3D,
 ):
     """Generates OFM_BLK_HEIGHT/WIDTH/DEPTH registers"""
-    block_config = npu_op.block_config
-    assert block_config is not None, "block_config has not been set"
-    alloc = shared_buffer.try_block(Block(block_config.width, block_config.height, block_config.depth))
-    assert alloc is not None, f"Block config {block_config} does not fit, op: {npu_op.op_type}"
     emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, block_config.height - 1)
     emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_WIDTH_M1, block_config.width - 1)
     emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_DEPTH_M1, block_config.depth - 1)
 
 
-def generate_shram_registers_elementwise(
-    emit: CommandStreamEmitter,
-    npu_op: NpuElementWiseOperation,
-    arch: ArchitectureFeatures,
-    shared_buffer: SharedBufferAllocation,
+def generate_shram_registers(
+    emit: CommandStreamEmitter, npu_op: NpuBlockOperation, arch_block_config: ArchitectureBlockConfig,
 ):
-    """Generates IB_END/IB_START/AB_START registers for elementwise operations"""
-    # For elementwise set the required SHRAM to be equal to the total size of available SHRAM
-    uses_lut = npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP
-    shram_required = arch.available_shram_banks(uses_lut)
-
-    # Acc buffers not needed so set AB_START to size of SHRAM
-    emit.cmd0_with_param(cmd0.NPU_SET_IFM_IB_END, shram_required)
-    emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shram_required)
+    """Generates IB_END/IB_START/AB_START/ACC_FORMAT registers"""
+    emit.cmd0_with_param(cmd0.NPU_SET_IFM_IB_END, arch_block_config.layout.ib_end)
+    emit.cmd0_with_param(cmd0.NPU_SET_AB_START, arch_block_config.layout.ab_start)
     if has_ifm2(npu_op):
-        # Set IFM2_IB_START to the latter half of the IB space
-        ifm_ib_start = shared_buffer.bank_locations[SharedBufferArea.IFM]
-        emit.cmd0_with_param(
-            cmd0.NPU_SET_IFM2_IB_START, (shram_required - ifm_ib_start) // shared_buffer.ifm_count + ifm_ib_start,
-        )
-    emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[shared_buffer.use_accumulator_element])
+        emit.cmd0_with_param(cmd0.NPU_SET_IFM2_IB_START, arch_block_config.layout.ib_start2)
+    emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[arch_block_config.acc_type])
 
 
-def generate_shram_registers_non_elementwise(emit: CommandStreamEmitter, shared_buffer: SharedBufferAllocation):
-    """Generates IB_END/IB_START/AB_START registers for non-elementwise operations"""
-    emit.cmd0_with_param(
-        cmd0.NPU_SET_IFM_IB_END,
-        shared_buffer.bank_locations[SharedBufferArea.IFM] + shared_buffer.banks_required[SharedBufferArea.IFM],
-    )
-    emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shared_buffer.bank_locations[SharedBufferArea.Accumulators])
-    emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[shared_buffer.use_accumulator_element])
+def get_block_config_for_npu_op(
+    arch, npu_op: NpuBlockOperation, npu_block_type: NpuBlockType, is_partkernel: bool, ifm_resampling: resampling_mode
+) -> Optional[ArchitectureBlockConfig]:
+    """
+    Given npu_op.block_config, returns a corresponding ArchitectureBlockConfig.
+    Returns None if the block_config does not fit.
+    """
 
 
-def create_shared_buffer(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> SharedBufferAllocation:
+def get_arch_block_config(
+    npu_op: NpuBlockOperation, block_traversal: NpuBlockTraversal, arch: ArchitectureFeatures
+) -> ArchitectureBlockConfig:
     """Creates shared buffer allocation for the given operation"""
+    assert npu_op.block_config is not None, "block_config has not been set"
+    block_type = NpuBlockType.Default
     if isinstance(npu_op, NpuConv2DOperation):
         block_type = NpuBlockType.ConvolutionMxN
     elif isinstance(npu_op, NpuConvDepthWiseOperation):
@@ -582,7 +564,37 @@
     else:
         assert 0, "Unsupported operation"
     ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
-    return shared_buffer_allocation_for_npu_op(arch, npu_op, block_type, ifm_resampling_mode)
+    is_partkernel = block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST
+    uses_lut = npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP
+    lut_banks = 2 if uses_lut else 0
+    fms = [npu_op.ifm, npu_op.ofm]
+    if npu_op.ifm2 is not None:
+        fms.append(npu_op.ifm2)
+    all_fms_have_quant = not any(fm.quantization is None or fm.quantization.scale_f32 is None for fm in fms)
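+    # try_block_config is only told the data is scaled when every feature map has a quantization scale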
+    ifm_bits = npu_op.ifm.data_type.size_in_bits()
+    ifm_shape = shape3d_to_block(npu_op.ifm.shape)
+    if has_ifm2(npu_op):
+        ifm2_shape = shape3d_to_block(npu_op.ifm2.shape)
+    else:
+        ifm2_shape = None
+    uses_scalar = npu_op.ifm2_scalar is not None
+    block_config = shape3d_to_block(npu_op.block_config)
+    arch_block_config = try_block_config(
+        block_config,
+        arch,
+        block_type,
+        ifm_shape,
+        ifm2_shape,
+        uses_scalar,
+        ifm_bits,
+        is_partkernel=is_partkernel,
+        kernel=to_kernel(npu_op.kernel),
+        lut_banks=lut_banks,
+        scaled=all_fms_have_quant,
+        ifm_resampling=ifm_resampling_mode,
+    )
+    assert arch_block_config is not None, f"block_config {npu_op.block_config} does not fit"
+    return arch_block_config
 
 
 def generate_cmd_waits(emit: CommandStreamEmitter, cmd_waits: Watermark):
@@ -617,12 +629,9 @@
     generate_weights(emit, npu_op.weights, arch)
     generate_biases(emit, npu_op.biases, arch)
     generate_activation(emit, npu_op.activation, npu_op.ofm)
-    shared_buffer = create_shared_buffer(npu_op, arch)
-    generate_block_config(emit, npu_op, arch, shared_buffer)
-    if isinstance(npu_op, NpuElementWiseOperation):
-        generate_shram_registers_elementwise(emit, npu_op, arch, shared_buffer)
-    else:
-        generate_shram_registers_non_elementwise(emit, shared_buffer)
+    arch_block_config = get_arch_block_config(npu_op, block_traversal, arch)
+    generate_block_config(emit, npu_op.block_config)
+    generate_shram_registers(emit, npu_op, arch_block_config)
 
 
 # -------------------------------------------------------------------
@@ -1025,10 +1034,10 @@
     Internal implementation of the public facing API for finding block configs.
     """
     if isinstance(npu_op, NpuBlockOperation):
+        # TODO: implement this function
         arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator))
-        shared_buffer = create_shared_buffer(npu_op, arch)
-        blocks = find_suitable_block_configs(arch, shared_buffer)
-        return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks]
+        block = arch.ofm_ublock
+        return [NpuShape3D(height=block.height, width=block.width, depth=block.depth)]
     return []
 
 
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
index ee70b7b..3751d88 100644
--- a/ethosu/vela/register_command_stream_util.py
+++ b/ethosu/vela/register_command_stream_util.py
@@ -76,6 +76,10 @@
     return Rect(0, 0, 0, shape.width - 1, shape.height - 1, shape.depth - 1)
 
 
+def shape3d_to_block(shape: NpuShape3D) -> Block:
+    return Block(shape.width, shape.height, shape.depth)
+
+
 # -------------------------------------------------------------------
 # ADDRESSING/STRIDES (helper functions)
 # -------------------------------------------------------------------
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 65d3313..00a4dfc 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -13,1156 +13,1059 @@
 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
 # Description:
-# The scheduler costs various strategies for scheduling the network in order to select the block configuration.
+# The scheduler creates and searches for an optimal plan for the network, selecting block configurations and
+# subdivisions for the Operators
 import copy
-import enum
-from functools import lru_cache
-
-import numpy as np
+from enum import auto
+from enum import IntEnum
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
 
 from . import live_range
 from . import npu_performance
-from . import stats_writer
+from . import tensor_allocation
+from . import weight_compressor
+from .architecture_allocator import ArchitectureBlockConfig
+from .architecture_allocator import find_block_config
+from .architecture_allocator import get_ifm_area_required
+from .architecture_allocator import to_upscale
+from .architecture_features import ArchitectureFeatures
+from .architecture_features import Block
+from .cascade_builder import CascadeBuilder
+from .cascade_builder import CascadeInfo
 from .data_type import DataType
-from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
 from .nn_graph import CascadedPass
+from .nn_graph import Graph
+from .nn_graph import Pass
 from .nn_graph import PassPlacement
-from .nn_graph import SchedulerRewrite
 from .nn_graph import SchedulingStrategy
-from .npu_performance import make_bandwidth_array
-from .npu_performance import make_cycles_array
-from .npu_performance import make_metrics_arrays
-from .npu_performance import PassCycles
+from .nn_graph import Subgraph
+from .numeric_util import round_down
+from .numeric_util import round_up
 from .operation import NpuBlockType
 from .operation import Op
-from .operation import Operation
-from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
-from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
+from .shape4d import Shape4D
 from .tensor import MemArea
 from .tensor import MemType
+from .tensor import Tensor
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
 from .tensor import TensorSubPurpose
 
 
-class ParetoMetric(enum.Enum):
-    BwCycMem = 1
-    BwCycMemBlkH = 2
+def shape_for_format(shape: Shape4D, tensor_format: TensorFormat) -> Shape4D:
+    if tensor_format == TensorFormat.NHCWB16:
+        return shape.with_depth(round_up(shape.depth, 16))
+
+    return shape
+
+
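+# Editorial note: shape_for_format only touches the depth - NHCWB16 stores feature maps in
+# 16-channel bricks, so the storage depth is rounded up to the next multiple of 16, while other
+# formats pass through unchanged. A small worked sketch with arbitrary values (imports as used
+# by this module):
+#
+#     from ethosu.vela.scheduler import shape_for_format
+#     from ethosu.vela.shape4d import Shape4D
+#     from ethosu.vela.tensor import TensorFormat
+#
+#     s = Shape4D(1, 32, 32, 20)
+#     shape_for_format(s, TensorFormat.NHCWB16)  # depth 20 rounds up to 32 -> Shape4D(1, 32, 32, 32)
+#     shape_for_format(s, TensorFormat.NHWC)     # returned unchanged
+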
+class OptimizationStrategy(IntEnum):
+    """Enum defining the different optimization strategies for the Scheduler"""
+
+    Size = auto()
+    Performance = auto()
 
     def __str__(self):
         return self.name
 
 
-class SchedulerOptions:
+class SchedulerOpInfo:
+    """Contains metadata about a SchedulerOperation that is unique to one Schedule"""
+
     def __init__(
         self,
-        use_cascading=True,
-        verbose_schedule=False,
-        verbose_pareto_frontier_schedules=False,
-        use_ifm_streaming=True,
-        pareto_metric=ParetoMetric.BwCycMem,
-        use_nhcwb16_between_cascaded_passes=True,
-        cache_bias_scale_tensor=True,
+        block_config: ArchitectureBlockConfig,
+        weights_size: int,
+        stripe_input: Shape4D,
+        stripe_input2: Optional[Shape4D],
+        stripe: Shape4D,
     ):
-        self.use_cascading = use_cascading
+        self.block_config = block_config
+        self.weights_size = weights_size
+        self.stripe_input = stripe_input
+        self.stripe_input2 = stripe_input2
+        self.stripe = stripe
+        self.cascade = 0  # Assigned by CascadeBuilder. 0 means not part of a cascade
+        self.time_index = None  # Set by update_op_memory_snapshot
+        self.ofm_depth_slices: List[int] = [0, stripe.depth]
+        self.npu_weights_tensor = None
+        self.buffered_weight_tensor = None
+        self.cycles = None
+        self.slack_buffering_cycles = 0
+        self.slack_buffering_memory = 0
+        self.full_weight_transfer_cycles = 0
+
+    def copy(self):
+        res = SchedulerOpInfo(self.block_config, self.weights_size, self.stripe_input, self.stripe_input2, self.stripe,)
+        res.cascade = self.cascade
+        return res
+
+    def __str__(self):
+        res = f"\t\tBlock Config = {self.block_config}\n"
+        res += f"\t\tOFM Block = {self.block_config.ofm_block}\n"
+        res += f"\t\tIFM Stripe   = {self.stripe_input}\n"
+        res += f"\t\tIFM2 Stripe  = {self.stripe_input2}\n"
+        res += f"\t\tOFM Stripe   = {self.stripe}\n"
+        res += f"\t\tEncoded Weights = {self.npu_weights_tensor and len(self.npu_weights_tensor.buffer)} bytes\n"
+        res += (
+            f"\t\tWeight buffer = {self.buffered_weight_tensor and self.buffered_weight_tensor.storage_size()} bytes\n"
+        )
+        res += f"\t\tDepth slices = {self.ofm_depth_slices}\n"
+        res += f"\t\tAssigned Cascade = {self.cascade}"
+        return res
+
+
+class SchedulerOptions:
+    """Contains options for the Scheduler"""
+
+    def __init__(
+        self, optimization_strategy, sram_target, verbose_schedule,
+    ):
+        self.optimization_strategy = optimization_strategy
+        self.optimization_sram_limit = sram_target
         self.verbose_schedule = verbose_schedule
-        self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
-        self.use_ifm_streaming = use_ifm_streaming
-        self.pareto_metric = pareto_metric
-        self.use_nhcwb16_between_cascaded_passes = use_nhcwb16_between_cascaded_passes
-        self.cache_bias_scale_tensor = cache_bias_scale_tensor
 
-    def __str__(self):
-        return type(self).__name__ + ": " + str(self.__dict__)
+    def __str__(self) -> str:
+        return f"{type(self).__name__}: {str(self.__dict__)}"
 
     __repr__ = __str__
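Aside on the new options class: it is deliberately small - the strategy selects between Size and Performance, and sram_target is the SRAM budget, in bytes, that the optimisation is allowed to use. A minimal construction sketch (values illustrative only):

    opts = SchedulerOptions(
        optimization_strategy=OptimizationStrategy.Performance,
        sram_target=384 * 1024,   # illustrative byte budget
        verbose_schedule=False,
    )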
 
 
-class Strategy:
-    __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"
+class SchedulerTensor:
+    def __init__(self, shape, dt, mem_area, _format):
+        self.dtype = dt
+        self.mem_area = mem_area
+        self.shape = shape
+        self.format = _format
+        self.connection = None
 
-    def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
-        self.strat = strat
-        self.param = param
-        self.passes = passes
-        self.block_configs = block_configs
-        self.rewrite_list = (
-            rewrite_list  # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
+
+class SchedulerOperation:
+    """Scheduler internal representation of 'Operation'
+    This class can be seen as a node within the Scheduler Graph representation
+    """
+
+    def __init__(self, ps: Pass, arch: ArchitectureFeatures, nng: Graph):
+        self.arch = arch
+        self.parent_ps = ps
+        self.parent_op = ps.primary_op
+        self.name = ps.primary_op.name
+        self.op_type = ps.primary_op.type
+        self.activation = ps.primary_op.activation
+        self.kernel = ps.primary_op.kernel
+        self.resampling_mode = ps.primary_op.ifm.resampling_mode
+        self.uses_scalar = ps.primary_op.ifm2 is not None and (
+            ps.primary_op.ifm.shape == [] or ps.primary_op.ifm2.shape == []
         )
-        self.bws = bws
-        self.macs = macs
-        self.cycles = cycles
-        self.sram_used = sram_used
+        self.ifm_ublock = arch.ifm_ublock
 
-    def __eq__(self, other):
-        if self.strat != other.strat:
-            return False
-        if self.param != other.param:
-            return False
-        if self.block_configs != other.block_configs:
-            return False
-        if self.passes != other.passes:
-            return False
-        if (self.bws != other.bws).any():
-            return False
-        if self.macs != other.macs:
-            return False
-        if (self.cycles != other.cycles).any():
-            return False
-        if self.sram_used != other.sram_used:
-            return False
-        return True
+        self.ifm = SchedulerTensor(ps.ifm_shapes[0], ps.ifm_tensor.dtype, ps.ifm_tensor.mem_area, ps.ifm_tensor.format,)
 
-    def empty(self):
-        return not self.passes
+        self.ifm2 = None
+        if ps.ifm2_tensor:
+            self.ifm2 = SchedulerTensor(
+                ps.ifm_shapes[1], ps.ifm2_tensor.dtype, ps.ifm2_tensor.mem_area, ps.ifm2_tensor.format,
+            )
 
-    def key(self):
-        return self.passes[-1]
+        self.ofm = SchedulerTensor(ps.ofm_shapes[0], ps.ofm_tensor.dtype, ps.ofm_tensor.mem_area, ps.ofm_tensor.format,)
 
-    def clone(self):
-        return Strategy(
-            self.strat,
-            self.param,
-            self.passes,
-            self.block_configs,
-            self.rewrite_list,
-            self.bws,
-            self.macs,
-            self.cycles,
-            self.sram_used,
-        )
+        # Input volume width and height required to produce the smallest possible stripe
+        self.min_stripe_input_w, self.min_stripe_input_h = self._calculate_min_stripe_input()
 
-    def __str__(self):
-        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
-            self.strat,
-            self.passes,
-            self.rewrite_list,
-            self.bws,
-            self.macs,
-            self.cycles,
-            self.sram_used,
-        )
+        # Flags that mark whether this SchedulerOperation requires full IFM/OFM
+        self.requires_full_ifm = False
+        self.requires_full_ifm2 = False
+        self.requires_full_ofm = False
 
-    __repr__ = __str__
+        self.index = 0
 
+    def add_ifm_connection(self, conn: "Connection"):
+        """Add input connection to another SchedulerOperation or Subgraph Input"""
+        conn.consumers.append(self)
+        self.ifm.connection = conn
 
-class StrategySet:
-    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"
-
-    def __init__(self, strats=None):
-        if strats is None:
-            strats = dict()
-        self.strats = strats  # final pass in packed pass -> Strategy
-        self.bws, self.macs, self.cycles = make_metrics_arrays()
-        self.max_sram_used = 0
-        self.total_sram_used = 0
-
-    def update_statistics(self):
-        self.bws = make_bandwidth_array()
-        self.max_sram_used = 0
-        for ps, strat in self.strats.items():
-            self.bws += strat.bws
-            self.macs += strat.macs
-            self.cycles += strat.cycles
-            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
-            self.total_sram_used += strat.sram_used
-
-    def clone_add_strategy(self, new_strat):
-        key = new_strat.key()
-        if key in self.strats:
-            assert new_strat == self.strats[key]
-            return self
+    def add_ifm2_connection(self, conn: "Connection"):
+        """Add input connection to another SchedulerOperation or Subgraph Input"""
+        if self.ifm2:
+            conn.consumers.append(self)
+            self.ifm2.connection = conn
         else:
-            new_strats = dict(self.strats)
-            new_strats[key] = new_strat
-            new_set = StrategySet(new_strats)
-            new_set.bws = self.bws + new_strat.bws
-            new_set.macs = self.macs + new_strat.macs
-            new_set.cycles = self.cycles + new_strat.cycles
-            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
-            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
-            return new_set
+            assert False, f"Trying to set an IFM2 Connection to {self} which has no IFM2"
 
-    def __eq__(self, other):
-        if (self.bws != other.bws).any():
-            return False
-        if self.macs != other.macs:
-            return False
-        if (self.cycles != other.cycles).any():
-            return False
-        if self.max_sram_used != other.max_sram_used:
-            return False
-        if self.total_sram_used != other.total_sram_used:
-            return False
-        if self.strats != other.strats:
-            return False
-        return True
+    def add_ofm_connection(self, conn: "Connection"):
+        """Add output connection to another SchedulerOperation or Subgraph Output"""
+        conn.producers.append(self)
+        self.ofm.connection = conn
+
+    def get_dependants(self):
+        """Returns a list of the Ops that depend on this Operation's OFM"""
+        return self.ofm.connection.consumers
+
+    def ifm_size_in_bytes(self) -> int:
+        """Returns size of the IFM in bytes"""
+        ifm_storage_shape = shape_for_format(self.ifm.shape, self.ifm.format)
+        return round_up(ifm_storage_shape.elements() * self.ifm.dtype.size_in_bytes(), Tensor.AllocationQuantum)
+
+    def ifm2_size_in_bytes(self) -> int:
+        """Returns size of the IFM2 in bytes"""
+        if self.ifm2:
+            ifm2_storage_shape = shape_for_format(self.ifm2.shape, self.ifm2.format)
+            return round_up(ifm2_storage_shape.elements() * self.ifm2.dtype.size_in_bytes(), Tensor.AllocationQuantum)
+
+        return 0
+
+    def ofm_size_in_bytes(self) -> int:
+        """Returns size of the OFM in bytes"""
+        ofm_storage_shape = shape_for_format(self.ofm.shape, self.ofm.format)
+        return round_up(ofm_storage_shape.elements() * self.ofm.dtype.size_in_bytes(), Tensor.AllocationQuantum)
+
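+    # Editorial note: the three size helpers above combine the format-aware shape with a
+    # round-up to Tensor.AllocationQuantum. As a rough worked example (assuming a 16-byte
+    # allocation quantum), an int8 OFM of shape 1x32x32x20 stored as NHCWB16 is costed as
+    # 1 * 32 * 32 * 32 * 1 = 32768 bytes; that is already a multiple of the quantum, so
+    # ofm_size_in_bytes() would report 32768.
+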
+    def create_scheduler_info(self, nng: Graph, stripe: Shape4D) -> SchedulerOpInfo:
+        """Returns schedule info about this SchedulerOperation based on how many ofm elements it should produce"""
+        ifm_shape = self.ifm.shape
+        ifm2_shape = self.ifm2 and self.ifm2.shape
+        ofm_shape = stripe
+
+        if ofm_shape != self.ofm.shape:
+            # Striped Op - Need to calculate stripe input volume
+            stripe_input_w, stripe_input_h = self._get_stripe_input_requirement(stripe)
+            # Ensure stripe input volume is within the full IFM volume
+            stripe_input_h = min(stripe_input_h, self.ifm.shape.height)
+            stripe_input_w = min(stripe_input_w, self.ifm.shape.width)
+            ifm_shape = ifm_shape.with_hw(stripe_input_h, stripe_input_w)
+
+            if self.ifm2:
+                stripe_input2_h = min(stripe_input_h, self.ifm2.shape.height)
+                stripe_input2_w = min(stripe_input_w, self.ifm2.shape.width)
+                ifm2_shape = ifm2_shape.with_hw(stripe_input2_h, stripe_input2_w)
+
+        block_config = self._get_block_config(ifm_shape, ifm2_shape, self.uses_scalar, ofm_shape)
+
+        scheduler_op_info = SchedulerOpInfo(block_config, 0, ifm_shape, ifm2_shape, ofm_shape)
+        if self.parent_op.weights:
+            # Default full-depth weight encoding with no buffering
+            scheduler_op_info.npu_weights_tensor = weight_compressor.encode_weight_and_scale_tensor(
+                self.arch,
+                self.parent_op,
+                self.parent_op.weights,
+                self.parent_op.bias,
+                self.kernel,
+                block_config,
+                [0, self.ofm.shape.depth],
+            )
+
+        self.parent_ps.block_config = block_config.old_style_representation()
+        return scheduler_op_info
+
+    def _get_stripe_input_requirement(self, stripe_shape: Shape4D) -> Tuple[int, int]:
+        """Returns the amount of IFM required to produce the stripe with shape:'stripe_shape'"""
+        ofm_shape_to_produce = Block.from_shape(stripe_shape.as_list())
+
+        return get_ifm_area_required(ofm_shape_to_produce, self.kernel, to_upscale(self.resampling_mode))
+
+    def _calculate_min_stripe_input(self) -> Tuple[int, int]:
+        # Calculate the input volume width and height required for the smallest possible stripe (h,w = 1,1)
+        min_stripe = self.ofm.shape.with_hw(1, 1)
+        return self._get_stripe_input_requirement(min_stripe)
+
+    def _get_block_config(
+        self, ifm_shape: Shape4D, ifm2_shape: Optional[Shape4D], uses_scalar: bool, ofm_shape: Shape4D
+    ) -> ArchitectureBlockConfig:
+        # Returns a block config and SHRAM layout
+        lut_banks = 2 if self.parent_op.activation_lut else 0
+        return find_block_config(
+            self.arch,
+            self.op_type.npu_block_type,
+            ofm_shape,
+            ifm_shape,
+            ifm2_shape,
+            uses_scalar,
+            self.ifm.dtype.size_in_bits(),
+            self.kernel,
+            lut_banks,
+            self.parent_op.has_scaling(),
+            self.resampling_mode,
+        )
+
+
+class Connection:
+    """Scheduler internal representation of a Tensor that connects two SchedulerOperations
+    This class can be seen as an edge within the Scheduler Graph representation
+    """
+
+    def __init__(self, tensor: Tensor):
+        self.parent_tens = tensor
+
+        # SchedulerOperation relationships
+        self.producers: List[SchedulerOperation] = []
+        self.consumers: List[SchedulerOperation] = []
 
     def __str__(self):
-        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
-            self.max_sram_used,
-            list(ps.name for ps in self.strats),
-        )
+        return f"<Connection {self.parent_tens.name}>"
 
     __repr__ = __str__
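Connections are the edges of the scheduler graph: the producing operation registers on the OFM side and each consuming operation on the IFM side. A minimal wiring sketch using only the classes above (producer_op, consumer_op and the shared tensor t are assumed to exist):

    conn = Connection(t)                    # t: the Tensor shared by the two operations
    producer_op.add_ofm_connection(conn)
    consumer_op.add_ifm_connection(conn)
    assert producer_op in conn.producers and consumer_op in conn.consumers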
 
 
-empty_strategy = Strategy(
-    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), 0, make_cycles_array(), 0
-)
-INFINITY = 1e30
+class Schedule:
+    """Class that contains a solution of how to schedule an NPU subgraph and its cost"""
 
-ABORT_SEARCH = []
+    def __init__(self, sg: Subgraph, label: str):
+        self.sg = sg
+        self.label = label
+        self.cost_map: Dict[SchedulerOperation, SchedulerOpInfo] = {}
+        self.cascades: Dict[int, CascadeInfo] = {}
+        self.fast_storage_peak_usage = 0
+        self.memory_snapshot = None
+
+    @property
+    def name(self):
+        return f"{self.sg.name}_{self.label}"
 
 
-def flatten_list_of_lists(lstlst):
-    lst = []
-    for v in lstlst:
-        lst.extend(v)
-    return lst
+class Scheduler:
+    """Main class of the Vela Scheduling"""
 
-
-class DynamicProgrammingScheduler:
-    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
+    def __init__(self, nng: Graph, sg: Subgraph, arch: ArchitectureFeatures, options: SchedulerOptions):
         self.nng = nng
         self.sg = sg
         self.arch = arch
-        self.sram_limit = sram_limit
-        self.options = copy.copy(options)
-        self.use_cascading = options.use_cascading
+        self.sched_ops: List[SchedulerOperation] = []
+        self.max_schedule = None
+        self.scheduler_options = options
 
-        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
-            self.use_ifm_ofm_overlap = False  # force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM
-        else:
-            self.use_ifm_ofm_overlap = True
-
-        self.verbose_schedule = options.verbose_schedule
-        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
-        self.mem_area = MemArea.Sram
-
-        self.bandwidth_weights = arch.bandwidth_weights
-        self.cycles_weight = arch.cycles_weight
-        self.max_sram_used_weight = arch.max_sram_used_weight
-
-        self.n_combinations_searched = 0
-
-        self.pareto_max_candidates = 16
-
-        self.ifm_stream_npu_blocks = set(
-            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
-        )
-
-    num_pareto_metrics = 4
-    view_values = ",".join(["d"] * num_pareto_metrics)
-    order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]
-
-    def pareto_metric(self, candidate):
-        strat, strat_set = candidate
-        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
-        bws = strat.bws + strat_set.bws
-        last_block_height = 0
-        if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
-            last_block_height = strat.block_configs[-1][0]
-
-        return (
-            np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
-            strat_set.max_sram_used,
-            strat.sram_used,
-            last_block_height,
-        )
-
-    def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):
-
-        candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]
-
-        if len(candidates) <= 1:
-            return candidates
-        assert remove_equally_good_candidates
-        pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
-        ids = np.arange(len(candidates), dtype=np.int32)
-        for idx, cand in enumerate(candidates):
-            pareto_vals[idx] = self.pareto_metric(cand)
-
-        sort_order = np.argsort(
-            pareto_vals.view(DynamicProgrammingScheduler.view_values),
-            order=DynamicProgrammingScheduler.order_values,
-            axis=0,
-            kind="stable",
-        ).flatten()
-        pareto_vals = pareto_vals[sort_order]
-        ids = ids[sort_order]
-
-        pareto_frontier = []
-        while len(ids) > 0:
-            pareto_frontier.append(candidates[ids[0]])
-            not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
-            ids = ids[not_dominated_by_first]
-            pareto_vals = pareto_vals[not_dominated_by_first]
-
-        if len(pareto_frontier) > self.pareto_max_candidates:
-            pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
-            pareto_frontier = pareto_frontier[: self.pareto_max_candidates]
-
-        return pareto_frontier
-
-    def candidate_metric(self, candidate):
-        strat, strat_set = candidate
-        max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
-        bws = strat.bws + strat_set.bws
-        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
-
-        return (
-            max_sram_used * self.max_sram_used_weight
-            + np.tensordot(bws, self.bandwidth_weights, axes=3)
-            + total_cycles * self.cycles_weight
-        )
-
-    def sort_by_candidate_metric(self, candidate_list):
-        sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
-        return sorted_list
-
-    def best_candidate(self, candidate_list):
-        if len(candidate_list) == 0:
-            return ABORT_SEARCH
-        if len(candidate_list) == 1:
-            return candidate_list[0]
-        sorted_list = self.sort_by_candidate_metric(candidate_list)
-        return sorted_list[0]
-
-    def graduate_strat(self, strat_type, sram_used, old_strat_data):
-        res = []
-        for old_strat, old_strat_set in old_strat_data:
-            if old_strat.sram_used + sram_used > self.sram_limit:
-                continue  # This strategy is bad, drop it
-            if old_strat_set.max_sram_used > self.sram_limit:
-                continue  # This strategy is bad, drop it
-            assert old_strat.strat == SchedulingStrategy.Unknown
-
-            new_strat = old_strat.clone()
-            new_strat.strat = strat_type
-            new_strat.sram_used = old_strat.sram_used + sram_used
-
-            if self.use_ifm_ofm_overlap:
-                overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
-                    new_strat.strat, new_strat.passes, new_strat.block_configs
-                )
-                new_strat.sram_used -= overlap
-
-            new_strat_set = old_strat_set.clone_add_strategy(new_strat)
-            res.append((empty_strategy, new_strat_set))
-        return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
-
-    def append_sram(self, sram_used, old_strat_data):
-        res = []
-        for old_strat, strat_set in old_strat_data:
-            assert old_strat.strat == SchedulingStrategy.Unknown
-            assert old_strat.sram_used == 0
-            new_strat = old_strat.clone()
-            new_strat.sram_used = old_strat.sram_used + sram_used
-
-            res.append((new_strat, strat_set))
-        return res
-
-    def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
-        res = []
-        for old_strat, strat_set in old_strat_data:
-            assert old_strat.strat == SchedulingStrategy.Unknown
-            new_strat = old_strat.clone()
-            bws, macs, cycles = metrics[:3]
-
-            new_strat.sram_used = old_strat.sram_used + sram_used
-            new_strat.block_configs = old_strat.block_configs + [block_config]
-            new_strat.bws = old_strat.bws + bws
-            new_strat.macs = old_strat.macs + macs
-            new_strat.cycles = old_strat.cycles + cycles
-            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
-                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
-            )
-
-            res.append((new_strat, strat_set))
-        return res
-
-    def append_sram_pass_block_config_performance_metrics_rewrite_list(
-        self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
-    ):
-        res = []
-        for old_strat, strat_set in old_strat_data:
-            assert old_strat.strat == SchedulingStrategy.Unknown
-            new_strat = old_strat.clone()
-            bws, macs, cycles = metrics[:3]
-            new_strat.sram_used = old_strat.sram_used + sram_used
-            new_strat.block_configs = old_strat.block_configs + [block_config]
-            new_strat.bws = old_strat.bws + bws
-            new_strat.macs = old_strat.macs + macs
-            new_strat.cycles = old_strat.cycles + cycles
-            new_strat.passes = old_strat.passes + [new_pass]
-            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
-                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
-            )
-            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
-            res.append((new_strat, strat_set))
-        return res
-
-    def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
-        res = []
-        for old_strat, strat_set in old_strat_data:
-            assert old_strat.strat == SchedulingStrategy.Unknown
-            new_strat = old_strat.clone()
-            new_strat.sram_used = old_strat.sram_used + sram_used
-            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
-            res.append((new_strat, strat_set))
-        return res
-
-    def pass_to_strat(self, strat_data):
-        res = {}
-        for strat in strat_data[1].strats.values():
-            for ps in strat.passes:
-                res[ps] = strat
-        return res
-
-    def compatible_strats(self, a, b):
-        intersection = a.keys() & b.keys()
-        for k in intersection:
-            if a[k] != b[k]:
-                return False
-        return True
-
-    def collate_strats_for_passes(self, all_passes):
-        if len(all_passes) == 0:
-            return [(empty_strategy, StrategySet(dict()))]
-        if len(all_passes) == 1:
-            return all_passes[0]  # save some space in the common case
-        all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
-        prev_combos = [dict()]
-        for j, strand in enumerate(all_strands):
-            new_combos = []
-            for i, alt in enumerate(strand):
-                for prev in prev_combos:
-                    if self.compatible_strats(prev, alt):
-                        cmb = dict(prev)
-                        cmb.update(all_passes[j][i][1].strats)
-                        new_combos.append(cmb)
-            prev_combos = new_combos
-
-        res = []
-        for d in prev_combos:
-            s = StrategySet(d)
-            s.update_statistics()
-            res.append((empty_strategy, s))
-        return res
-
-    def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
-        # get the rest of the predecessors
-        other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
-        other_predecessor_data = self.search_pass_list(other_predecessors)
-
-        # pred strat data has an incomplete strategy, which we need
-        # to continue on, whereas the other ones have completed strategies.
-        # we need to merge these, but keep the incomplete strategy too.
-
-        res = []
-        for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
-            all_strats = [
-                [(empty_strategy, pred_pass_strat_set)],  # pred strat data but with a dummy empty strategy
-                other_predecessor_data,  # this one is fine to use as-is
-            ]
-            collated_strat_data = self.collate_strats_for_passes(all_strats)
-            strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
-            res.extend(strat_data)
-        return res
-
-    def calc_non_local_mem_usage(self):
-        ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
-        range_set = live_range.extract_live_ranges_from_passes(
-            self.sg, self.mem_area, ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
-        )
-        range_dict = range_set.ranges
-
-        # find which ranges overlap passes but aren't input/outputs of the passes.
-        # these won't be counted by the dynamic programming search and must be counted in manually.
-        end_pos = max(ps.time for ps in self.sg.passes) + 2
-        mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
-        non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)
-
-        for tens, rng in range_dict.items():
-            storage_size = tens.storage_size()
-            assert tens.mem_area == self.mem_area
-            mem_usage[rng.start_time : rng.end_time] += storage_size
-
+    def create_scheduler_representation(self, arch: ArchitectureFeatures):
+        """Creates a Scheduler Graph representation"""
+        # Temporary dict for creating connections between the Operations
+        connections: Dict[Tensor, Connection] = {}
+        # Memory required for the largest FeatureMap that has to be full
+        min_memory_req = 0
         for ps in self.sg.passes:
-            local_mem_usage = 0
-            for tens in ps.inputs + ps.outputs + ps.intermediates:
-                if tens.mem_area != self.mem_area:
-                    continue
-
-                local_mem_usage += tens.storage_size()
-
-            non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage
-
-        self.non_local_mem_usage = non_local_mem_usage
-
-    def search(self):
-        self.calc_non_local_mem_usage()
-        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
-        strat_data = self.search_pass_list(starting_passes)
-
-        _, best_set = self.best_candidate(strat_data)
-
-        if self.verbose_pareto_frontier_schedules:
-            print(
-                "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
-                % (self.n_combinations_searched, len(strat_data))
-            )
-            for idx, (_, strat_set) in enumerate(strat_data):
-                extra = ""
-                if strat_set == best_set:
-                    extra = "(Best candidate)"
-                print("Candidate", idx, extra)
-                memory_used = {MemArea.Sram: strat_set.max_sram_used}
-                stats_writer.print_performance_metrics_for_strat(
-                    self.arch,
-                    "",
-                    strat_set.cycles,
-                    strat_set.macs,
-                    strat_set.bws,
-                    self.nng.batch_size,
-                    memory_used,
-                    len(self.sg.passes),
-                    len(strat_set.strats),
-                )
-
-        return best_set
-
-    def search_pass_list(self, pass_list):
-        all_strats = []
-        for ps in pass_list:
-            strat = self.search_output(ps)
-            all_strats.append(strat)
-        strat_data = self.collate_strats_for_passes(all_strats)
-        for strd in strat_data:
-            for ps in pass_list:
-                assert ps in strd[1].strats  # should have strategies for everything we asked to search
-        return strat_data
-
-    def search_predecessors(self, ps):
-
-        # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
-        # we have strats for all passes
-
-        pass_list = ps.dag_predecessors
-        strat_data = self.search_pass_list(pass_list)
-
-        return strat_data
-
-    @lru_cache(maxsize=None)
-    def search_output(self, ps):
-
-        assert ps in self.sg.passes
-        candidate_list = []
-
-        candidate_list.extend(self.search_weight_streaming_output(ps))
-
-        if self.options.use_ifm_streaming:
-            candidate_list.extend(self.search_ifm_streaming_output(ps))
-
-        best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)
-
-        if not best:
-            print(
-                "Warning: Dynamic search programming algorithm failed for pass %s, invoking fallback strategy"
-                % (ps.name,)
-            )
-            return self.search_predecessors(ps)
-
-        return best
-
-    def search_ifm_streaming_output(self, ps):
-        if ps.placement != PassPlacement.Npu:
-            return ABORT_SEARCH
-        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
-            return ABORT_SEARCH
-        strat_data = self.search_ifm_streaming_body(ps, False)
-
-        sram_used = self.non_local_mem_usage[ps.time]
-        for tens in ps.outputs:
-            if tens.mem_area == self.mem_area:
-                sram_used += tens.storage_size()
-
-        return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)
-
-    @lru_cache(maxsize=None)
-    def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
-        if ps.placement != PassPlacement.Npu:
-            return ABORT_SEARCH
-        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
-            return ABORT_SEARCH
-        ifm_input_search_resuls = self.search_ifm_streaming_input(ps)
-        res = []
-
-        base_sram_used = 0
-        for tens in ps.intermediates:
-            if tens.mem_area == self.mem_area:
-                if tens.purpose == TensorPurpose.Weights:
-                    base_sram_used = tens.storage_size(self.arch.weight_estimation_scaling)
-                else:
-                    base_sram_used += tens.storage_size()
-
-        all_block_configs = self.get_block_configs(ps)
-        for block_config in all_block_configs:
-            all_strats = []
-
-            if self.use_cascading:
-                all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))
-
-            all_strats.extend(ifm_input_search_resuls)
-
-            rewrite_list = []
-            sram_used = base_sram_used
-
-            metrics = npu_performance.performance_metrics_for_pass(
-                self.arch,
-                ps,
-                block_config,
-                rewrite_list=rewrite_list,
-                force_outputs_to_fast_storage=force_outputs_to_fast_storage,
-            )
-
-            res.extend(
-                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
-                    sram_used, ps, block_config, metrics, rewrite_list, all_strats
-                )
-            )
-
-        self.n_combinations_searched += len(res)
-        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
-        return res
-
-    def avoid_for_cascading(self, pred_candidate):
-        for op in pred_candidate.ops:
-            if (
-                op.memory_function == Op.ConcatSliceWrite
-                and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area
-            ):
-                # For SRAM spilling, concat op is avoided as predecessor
-                return True
-            if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
-                # The op has consumers in other subgraphs
-                return True
-        return False
-
-    def search_ifm_streaming_partial(self, ps, block_config):
-        if ps.placement != PassPlacement.Npu:
-            return ABORT_SEARCH
-
-        if len(ps.inputs) < 1:
-            return ABORT_SEARCH
-
-        ifm_tensor = ps.ifm_tensor
-
-        if ifm_tensor is None:
-            return ABORT_SEARCH
-        if ifm_tensor.purpose != TensorPurpose.FeatureMap:
-            return ABORT_SEARCH
-        if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
-            return ABORT_SEARCH
-
-        pred_pass_list = []
-        for pred_candidate in ps.dag_predecessors:
-            if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
-                # we found a predecessor that produces this IFM tensor
-                if not ifm_tensor.needs_linear_format:
-                    # and NHCWB16 can be used
-                    if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
-                        # and it only has one successor, namely us
-                        if pred_candidate.placement == PassPlacement.Npu:
-                            if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
-                                # and it is on the Npu
-                                if not self.avoid_for_cascading(pred_candidate):
-                                    # and fusable - it's a candidate
-                                    pred_pass_list.append(pred_candidate)
-
-        if not pred_pass_list:
-            return ABORT_SEARCH
-
-        all_candidates = []
-        for pred_pass in pred_pass_list:
-            # recurse into the next pass
-            ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.arch.is_spilling_enabled())
-
-            strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
-            for strat_opt in strat_data:
-
-                pred_pass_block_config = strat_opt[0].block_configs[-1]
-                rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
-                    self.arch, pred_pass, pred_pass_block_config, ps, block_config
-                )
-                if rolling_buffer_dims is None:
-                    continue  # this does not pack properly, skip it.
-
-                sram_used = 0
-                for tens in ps.inputs:
-                    if tens != ifm_tensor:
-                        if tens.mem_area == self.mem_area:
-                            sram_used += tens.storage_size()
-
-                rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims
-
-                rewrite_list = [
-                    (
-                        SchedulerRewrite.ChangeTensorSubPurpose,
-                        ifm_tensor,
-                        TensorSubPurpose.RollingBufferY,
-                        rolling_buffer_y,
-                        None,
-                        ps,
-                    )
-                ]
-                sram_used += ifm_tensor.storage_size_for_sub_purpose(
-                    self.arch, TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
-                )
-
-                all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))
-
-        self.n_combinations_searched += len(all_candidates)
-        return all_candidates
-
-    def get_block_configs(self, ps):
-        if ps.placement != PassPlacement.Npu:
-            return [(1, 1, 1, 1)]  # default
-
-        block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)
-
-        # Take a limited number of the largest blocks
-        if self.arch.block_config_limit > 0:
-            # Sort by block area, followed by depth
-            block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
-            bound = min(len(block_configs), self.arch.block_config_limit)
-            # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
-            tmp = block_configs[:bound]
-            tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
-            block_configs = tmp
-
-        return block_configs
-
-    def search_ifm_streaming_input(self, ps):
-        sram_used = 0
-        for tens in ps.inputs:
-            if tens.mem_area == self.mem_area:
-                sram_used += tens.storage_size()
-
-        return self.append_sram(sram_used, self.search_predecessors(ps))
-
-    def search_weight_streaming_output(self, ps):
-        strat_data = self.search_weight_streaming_body(ps)
-
-        sram_used = self.non_local_mem_usage[ps.time]
-        for tens in ps.outputs:
-            if tens.mem_area == self.mem_area:
-                sram_used += tens.storage_size()
-
-        return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)
-
-    @lru_cache(maxsize=None)
-    def search_weight_streaming_body(self, ps):
-
-        strat_data = self.search_weight_streaming_input(ps)
-
-        res = []
-
-        all_block_configs = self.get_block_configs(ps)
-
-        for block_config in all_block_configs:
-
-            sram_used = 0
-            rewrite_list = []
-
-            for tens in ps.intermediates:
-                if tens.mem_area == self.mem_area:
-                    if tens.purpose == TensorPurpose.Weights:
-                        sram_used += tens.storage_size_for_sub_purpose(
-                            self.arch, TensorSubPurpose.DoubleBuffer, block_config[3]
-                        )
-                        rewrite_list.append(
-                            (
-                                SchedulerRewrite.ChangeTensorSubPurpose,
-                                tens,
-                                TensorSubPurpose.DoubleBuffer,
-                                block_config[3],
-                                None,
-                                ps,
-                            )
-                        )
-                    else:
-                        sram_used += tens.storage_size()
-
-            metrics = npu_performance.performance_metrics_for_pass(
-                self.arch, ps, block_config, rewrite_list=rewrite_list
-            )
-
-            res.extend(
-                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
-                    sram_used, ps, block_config, metrics, rewrite_list, strat_data
-                )
-            )
-
-        self.n_combinations_searched += len(res)
-        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
-        return res
-
-    def search_weight_streaming_input(self, ps):
-        sram_used = 0
-        for tens in ps.inputs:
-            if tens.mem_area == self.mem_area:
-                sram_used += tens.storage_size()
-
-        return self.append_sram(sram_used, self.search_predecessors(ps))
-
-    def apply_result(self, strat_set, arch):
-        pass_to_cascaded_pass = dict()
-        for _, strat in strat_set.strats.items():
-            # rewrite the tensors that need this first. e.g. make rolling buffers
-            inputs = []
-            intermediates = []
-            outputs = []
-
-            for ps in strat.passes:
-                inputs += ps.inputs
-                intermediates += ps.intermediates
-                outputs += ps.outputs
-
-            for tens in set(inputs) & set(outputs):
-                # tensors that are in both sets are intermediates
-
-                # find pass with input/output tensor, and check if they are both placed on NPU
-                input_placement = None
-                output_placement = None
-                for ps in strat.passes:
-                    if tens in ps.inputs:
-                        input_placement = ps.placement
-                    if tens in ps.outputs:
-                        output_placement = ps.placement
-                if input_placement == output_placement == PassPlacement.Npu:
-                    tens.set_format(TensorFormat.NHCWB16, arch)
-
-                intermediates.append(tens)
-                inputs.remove(tens)
-                outputs.remove(tens)
-
-            for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
-                if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
-                    tens.mem_area = self.arch.fast_storage_mem_area
-                    tens.mem_type = MemType.Scratch_fast
-                    tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
-                else:
-                    assert 0, "unknown rewrite_op " + str(rewrite_op)
-
-            is_element_wise = True
-            for ps in strat.passes:
-                assert ps.placement == strat.passes[0].placement
-                if not ps.is_element_wise:
-                    is_element_wise = False
-                    break
-
-            cascaded_pass = CascadedPass(
-                strat.passes[0].name,
-                strat.strat,
-                inputs,
-                intermediates,
-                outputs,
-                strat.passes,
-                strat.passes[0].placement,
-                is_element_wise,
-            )
-            assert strat.sram_used >= 0
-            cascaded_pass.sram_used = strat.sram_used
-
-            for idx, ps in enumerate(strat.passes):
-                assert ps not in pass_to_cascaded_pass
-                pass_to_cascaded_pass[ps] = cascaded_pass
-                ps.cascade = cascaded_pass
-                ps.block_config = strat.block_configs[idx]
-
-                if ps.placement == PassPlacement.Npu:
-                    ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
-                        self.arch, ps, ps.block_config
-                    )
-                    assert ps.shared_buffer is not None
-
-                sram_used = max(self.non_local_mem_usage[ps.time], 0)
-                for op in ps.ops:
-                    subgraph = op.attrs.get("subgraph")
-                    if subgraph:
-                        subgraph.base_sram_used = sram_used
-
-        # all passes should have a cascaded pass now
-        if len(pass_to_cascaded_pass) != len(self.sg.passes):
-            print(
-                "mismatch: we have %d passes, but only %d have cascaded passes associated"
-                % (len(self.sg.passes), len(pass_to_cascaded_pass))
-            )
-            for ps in self.sg.passes:
-                if ps not in pass_to_cascaded_pass:
-                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))
-
-            assert len(pass_to_cascaded_pass) == len(self.sg.passes)
-
-        cascaded_passes = []
-        if self.sg.placement == PassPlacement.Cpu:
-            # Retain the pass order for CPU subgraph
-            cascaded_passes = [ps.cascade for ps in self.sg.passes]
-        else:
-            # we have all the passes, but we need to put them in order and build predecessor/successor links.
-            visit_pass_set = set()
-
-            def visit_pass(ps):
-                if ps in visit_pass_set:
-                    return
-                visit_pass_set.add(ps)
-
-                cps = ps.cascade
-                dont_traverse = set(cps.passes)
-
-                for ps in cps.passes:
-                    for pred in ps.predecessors:
-                        if pred in dont_traverse:
-                            continue
-                        visit_pass(pred)
-
-                cascaded_passes.append(cps)
-
-            starting_passes = [ps for ps in self.sg.passes if not ps.successors]
-            for ps in starting_passes:
-                visit_pass(ps)
-
-        # reorder so startup init cascaded passes come first
-        def is_startup_cascaded_pass(cps):
-            if not cps.passes:
-                return False
-            return cps.placement == PassPlacement.StartupInit
-
-        cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
-            cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
-        ]
-
-        self.sg.cascaded_passes = cascaded_passes
-        self.sg.build_cascaded_pass_links()
-
-        # Check if NHCWB16 and/or fast storage can be used in between cascaded passes
-        # (NHCWB16 within cascaded passes has been handled earlier in this function)
-        if self.sg.placement == PassPlacement.Npu:
-            # Dictionary tensor -> list of ops, containing feature maps that can be attempted
-            # to be moved to fast storage
-            fast_storage_tensor_rewrites = {}
-            last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
-            # Memory only passes have no primary_op, so use the last op in ops
-            if last_op_in_subgraph is None:
-                last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].ops[-1]
-            for ps in self.sg.cascaded_passes:
-                if ps.placement != PassPlacement.Npu:
-                    continue
+            if ps.primary_op:
+                # Set tensor format to NHCWB16 for output FeatureMaps, if possible
                 for output in ps.outputs:
                     if output.purpose != TensorPurpose.FeatureMap:
                         continue
-
-                    use_NHCWB16 = not output.needs_linear_format
-                    use_fast_storage = True
-                    rewrites = []
-                    for op in output.consumer_list:
-                        if op is None:
-                            use_NHCWB16 = False
-                            use_fast_storage = False
-                            continue
-                        if op.type == Op.ReduceSum and output.dtype == DataType.int32:
-                            use_NHCWB16 = False
-                        elif op.type == Op.Reshape:
-                            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
-                            # consumers do not also need to perform a reshape or if the OFM is going to
-                            # be processed by CPU operations. No-op reshape consumers with empty lists
-                            # (those that have no consumers, or null-consumers used as list terminators)
-                            # must use normal NHWC output.
-                            def incompatible_consumers(oper):
-                                if oper and oper.type == Op.Reshape:
-                                    for consumer in oper.outputs[0].consumer_list:
-                                        yield from incompatible_consumers(consumer)
-                                yield not oper or not oper.run_on_npu or oper is last_op_in_subgraph
-
-                            if not any(incompatible_consumers(op)):
-
-                                def get_rewrites(oper):
-                                    if oper and oper.type == Op.Reshape:
-                                        for consumer in oper.outputs[0].consumer_list:
-                                            yield from get_rewrites(consumer)
-                                        yield oper
-
-                                rewrites.extend(get_rewrites(op))
-                                # Detect no-op reshapes by comparing their full input and output tensor shapes.
-                                inshape = op.ifm_shapes[0]
-                                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
-                                use_NHCWB16 &= compatible_shape and all(compatible_shape)
-                            else:
-                                use_NHCWB16 = False
-                                use_fast_storage = False
-                        use_NHCWB16 &= op.run_on_npu
-                        use_fast_storage &= op.run_on_npu
-
-                    if use_fast_storage:
-                        fast_storage_tensor_rewrites[output] = rewrites
-                    if use_NHCWB16 and self.options.use_nhcwb16_between_cascaded_passes:
+                    if not output.needs_linear_format:
                         output.set_format(TensorFormat.NHCWB16, arch)
-                        for rewrite_op in rewrites:
-                            rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
-            if arch.is_spilling_enabled():
-                # Remember feature maps that can be moved to fast storage for later use
-                # in use_fast_storage_for_feature_maps
-                self.sg.scheduling_info["feature_map_rewrites"] = fast_storage_tensor_rewrites
+
+                # Create SchedulerOperations
+                op = SchedulerOperation(ps, arch, self.nng)
+                op.index = len(self.sched_ops)
+
+                # Make connections
+                if ps.ifm_tensor not in connections:
+                    connections[ps.ifm_tensor] = Connection(ps.ifm_tensor)
+                if ps.ifm2_tensor and ps.ifm2_tensor not in connections:
+                    connections[ps.ifm2_tensor] = Connection(ps.ifm2_tensor)
+                if ps.ofm_tensor not in connections:
+                    connections[ps.ofm_tensor] = Connection(ps.ofm_tensor)
+
+                op.add_ifm_connection(connections[ps.ifm_tensor])
+                if ps.ifm2_tensor:
+                    op.add_ifm2_connection(connections[ps.ifm2_tensor])
+                op.add_ofm_connection(connections[ps.ofm_tensor])
+
+                # Set requirements on the ifm/ofm buffers
+                self.sched_ops.append(op)
+                if ps.ifm_tensor in self.sg.input_tensors:
+                    # This Op consumes a subgraph input
+                    op.requires_full_ifm = True
+                if ps.ifm2_tensor and ps.ifm2_tensor in self.sg.input_tensors:
+                    # This Op consumes a subgraph input
+                    op.requires_full_ifm2 = True
+                if ps.ofm_tensor in self.sg.output_tensors:
+                    # This Op produces a subgraph output
+                    op.requires_full_ofm = True
+                if ps.ifm_tensor.needs_linear_format:
+                    op.requires_full_ifm = True
+                if ps.ifm2_tensor and ps.ifm2_tensor.needs_linear_format:
+                    op.requires_full_ifm2 = True
+                if ps.ofm_tensor.needs_linear_format or ps.primary_op.memory_function == Op.ConcatSliceWrite:
+                    op.requires_full_ofm = True
+                if len(ps.primary_op.outputs) > 1 or len(ps.primary_op.outputs[0].consumer_list) > 1:
+                    # Op has multiple outputs or consumers - requires full OFM
+                    op.requires_full_ofm = True
+
+                # Check memory requirements if this Op requires any full FeatureMaps
+                op_memory_req = 0
+                if op.requires_full_ifm:
+                    op_memory_req += op.ifm_size_in_bytes()
+                if op.requires_full_ifm2:
+                    op_memory_req += op.ifm2_size_in_bytes()
+                if op.requires_full_ofm:
+                    op_memory_req += op.ofm_size_in_bytes()
+
+                min_memory_req = max(op_memory_req, min_memory_req)
+
+        # Theoretical minimum required memory - used to guide the cascade building
+        self.min_memory_req = min_memory_req
+
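+    # Editorial note: min_memory_req is the worst case, over all operators, of the feature maps
+    # that must be kept resident in full. For instance (illustrative numbers only), an operator
+    # that consumes a subgraph input of 1x64x64x32 int8 and also produces a subgraph output of
+    # the same shape contributes roughly 2 * 64 * 64 * 32 = 262144 bytes; the largest such sum
+    # across the graph is what seeds the cascade builder's memory target.
+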
+    def create_initial_schedule(self) -> Schedule:
+        """Creates an initial schedule with no cascading or buffering of any kind"""
+        schedule = Schedule(self.sg, "MAX")
+
+        for op in self.sched_ops:
+            cost = op.create_scheduler_info(self.nng, op.ofm.shape)
+            cost.cycles = self.estimate_op_performance(op, cost.block_config, op.ofm.shape.depth)
+            schedule.cost_map[op] = cost
+
+        return schedule
+
+    def update_op_memory_snapshot(self, schedule: Schedule):
+        memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
+
+        # Collect live ranges from tensors
+        lr_graph = live_range.LiveRangeGraph()
+        for mem_area, mem_type_set in memories_list:
+            live_range.extract_live_ranges_from_cascaded_passes(
+                self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum,
+            )
+
+        # Populate time-array with memory used by live ranges
+        temporal_usage = lr_graph.get_temporal_memory_usage(self.arch.fast_storage_mem_area)
+        schedule.memory_snapshot = temporal_usage
+
+        # Set the peak memory usage
+        schedule.fast_storage_peak_usage = max(temporal_usage, default=0)
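+        # Note: temporal_usage is indexed by time slot, so the peak fast-storage usage is
+        # simply its largest entry (0 for an empty schedule, via the default)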
+
+    def estimate_op_performance(self, op: SchedulerOperation, block_config, ofm_depth):
+        query = npu_performance.PerformanceQuery(op.op_type.npu_block_type)
+        query.ifm_shape = op.ifm.shape
+        query.ifm_memory_area = op.ifm.mem_area
+        query.ifm_bits = op.ifm.dtype.size_in_bits()
+        query.ifm_format = op.ifm.format
+        query.ifm2_shape = op.ifm2 and op.ifm2.shape
+        query.ifm2_memory_area = op.ifm2 and op.ifm2.mem_area
+        query.ifm2_bits = op.ifm2 and op.ifm2.dtype.size_in_bits()
+        query.ifm2_format = op.ifm2 and op.ifm2.format
+        query.ofm_shape = op.ofm.shape.with_depth(ofm_depth)
+        query.ofm_memory_area = op.ofm.mem_area
+        query.ofm_bits = op.ofm.dtype.size_in_bits()
+        query.ofm_format = op.ofm.format
+        if op.parent_op.bias:
+            query.const_shape = Shape4D(1, 1, 1, op.ofm.shape.depth)
+            query.const_memory_area = self.arch.fast_storage_mem_area
+
+        query.kernel = op.kernel
+        query.config = block_config
+
+        return npu_performance.measure_cycle_cost(self.arch, op.op_type, op.activation and op.activation.op_type, query)
+
+    def propose_schedule_buffering(self, ref_schedule: Schedule):
+        """Create a buffered schedule"""
+        buffered_schedule = Schedule(self.sg, f"{ref_schedule.label}_BUFFERED")
+        staging_limit_bytes = self.scheduler_options.optimization_sram_limit
+
+        prev_op = None
+        for sched_op in self.sched_ops:
+            if sched_op not in ref_schedule.cost_map:
+                # sched_op is not part of this sub-schedule - skip
+                continue
+
+            self.propose_operator_buffering(sched_op, prev_op, buffered_schedule, ref_schedule, staging_limit_bytes)
+            prev_op = sched_op
+
+        return buffered_schedule
+
+    def propose_operator_buffering(
+        self,
+        sched_op: SchedulerOperation,
+        prev_op: SchedulerOperation,
+        buffered_schedule: Schedule,
+        ref_schedule: Schedule,
+        staging_limit_bytes,
+    ):
+        # Mild recursion might mean this Op has already been seen
+        if sched_op in buffered_schedule.cost_map:
+            return
+
+        # Take the reference schedule as default costings for this schedule
+        ref_cost = ref_schedule.cost_map[sched_op]
+        cost = copy.copy(ref_cost)
+        cost.slack_buffering_cycles = ref_cost.cycles.op_cycles
+        memory_snapshot = ref_schedule.memory_snapshot
+        ref_memory_usage = memory_snapshot[ref_cost.time_index] if ref_cost.time_index < len(memory_snapshot) else 0
+        cost.slack_buffering_memory = staging_limit_bytes - ref_memory_usage
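+        # i.e. the SRAM headroom at this Op's time index in the reference schedule;
+        # any weight buffer proposed below has to fit inside this slack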
+        buffered_schedule.cost_map[sched_op] = cost
+
+        # Attempt weight buffering on anything with a weights tensor
+        if sched_op.parent_op.weights:
+            self.propose_weight_buffering(
+                sched_op.parent_op.weights,
+                sched_op.parent_op.bias,
+                sched_op,
+                prev_op,
+                buffered_schedule,
+                ref_schedule,
+                cost.slack_buffering_memory,
+            )
+
+        return cost
+
+    def weights_needs_dma(self, weight_tensor):
+        if weight_tensor and weight_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
+            # Weights are in permanent storage
+            # Moving the data is only worthwhile when permanent storage differs from the feature map storage
+            if (
+                weight_tensor.mem_area in (MemArea.Dram, MemArea.OffChipFlash)
+                and self.arch.permanent_storage_mem_area != self.arch.fast_storage_mem_area
+            ):
+                return True
+        return False
+
+    def propose_weight_buffering(
+        self,
+        weight_tensor,
+        scale_tensor,
+        sched_op: SchedulerOperation,
+        prev_op: SchedulerOperation,
+        buffered_schedule: Schedule,
+        ref_schedule: Schedule,
+        buffer_limit_bytes,
+    ):
+        cost = buffered_schedule.cost_map[sched_op]
+        prev_cost = buffered_schedule.cost_map.get(prev_op)
+        ref_cost = ref_schedule.cost_map[sched_op]
+        assert cost and ref_cost
+
+        needs_dma = self.weights_needs_dma(weight_tensor)
+
+        ofm_full_depth_slices = [0, ref_cost.stripe.depth]
+
+        # Encode weights for the full depth
+        full_weights = weight_compressor.encode_weight_and_scale_tensor(
+            self.arch,
+            sched_op.parent_op,
+            weight_tensor,
+            scale_tensor,
+            sched_op.kernel,
+            cost.block_config,
+            ofm_full_depth_slices,
+        )
+        full_weights_bytes = len(full_weights.buffer)
+        cost.ofm_depth_slices = ofm_full_depth_slices
+
+        # No buffering required - take all the weights from permanent storage
+        if sched_op.op_type == Op.FullyConnected or not needs_dma:
+            cost.npu_weights_tensor = full_weights
+            return
+
+        encoded_weights = full_weights
+
+        # NPU cycles available under the previously executing operator, and the SRAM it
+        # leaves unused, that can be spent on buffered DMA transfers
+        slack_cycles = prev_cost.slack_buffering_cycles if prev_cost else 0
+        slack_memory = prev_cost.slack_buffering_memory if prev_cost else 0
+
+        # Force full depth for cascaded Ops
+        if ref_cost.cascade != 0:
+            weight_tensor_purpose = TensorSubPurpose.Standard
+            weight_buffer_size = full_weights_bytes
+            # Update the memory snapshot to reflect the added size of the weights
+            ref_schedule.memory_snapshot[ref_cost.time_index] += weight_buffer_size
+        else:
+            # Estimate the buffering cycle time for the full set of weights
+            full_transfer_cycles = npu_performance.measure_mem2mem_cycles(
+                self.arch, weight_tensor.mem_area, self.arch.fast_storage_mem_area, full_weights_bytes
+            )
+            cost.full_weight_transfer_cycles = full_transfer_cycles
+
+            # Calculate the amount of prebuffering necessary (or what is possible with a
+            # limited double-buffer size)
+            half_buffer_limit = buffer_limit_bytes // 2
+            if full_transfer_cycles > slack_cycles:
+                prebuffer_ratio = slack_cycles / full_transfer_cycles
+                prebuffer_bytes = min(prebuffer_ratio * full_weights_bytes, half_buffer_limit)
+            else:
+                prebuffer_bytes = min(full_weights_bytes, half_buffer_limit)
+                prebuffer_ratio = prebuffer_bytes / full_weights_bytes
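+            # Worked example (hypothetical numbers): with 40k slack cycles, a 100k-cycle full
+            # transfer and 96 KB of encoded weights, prebuffer_ratio is 0.4 and prebuffer_bytes
+            # is min(0.4 * 96 KB, half_buffer_limit)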
+
+            # Have to split the weights if the initial buffering can't store
+            # all of the compressed weights
+            if prebuffer_bytes < full_weights_bytes:
+                prebuffer_depth = int(ref_cost.stripe.depth * prebuffer_ratio)
+
+                # Round prebuffering down to nearest valid split depth
+                prebuffer_depth = int(max(16, round_down(prebuffer_depth, ArchitectureFeatures.OFMSplitDepth)))
+
+                while True:
+                    buffering_depth = max(cost.block_config.ofm_block.depth, prebuffer_depth)
+
+                    # Clamp buffering to the double buffering limit
+                    buffering_bytes = (buffering_depth / ref_cost.stripe.depth) * full_weights_bytes
+                    if buffering_bytes > half_buffer_limit:
+                        buffering_depth = (half_buffer_limit / full_weights_bytes) * ref_cost.stripe.depth
+                        buffering_depth = int(max(16, round_down(buffering_depth, ArchitectureFeatures.OFMSplitDepth)))
+
+                    # Create list of depth slices
+                    depth_slices = [0]
+                    if prebuffer_depth < ref_cost.stripe.depth:
+                        depth_slices += list(range(prebuffer_depth, ref_cost.stripe.depth, buffering_depth))
+                    depth_slices.append(ref_cost.stripe.depth)
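+                    # Example (hypothetical values): prebuffer_depth=32, buffering_depth=48 and a
+                    # stripe depth of 128 give depth_slices = [0, 32, 80, 128]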
+
+                    # Encode weights based on the depth slices
+                    cost.ofm_depth_slices = depth_slices
+                    encoded_weights = weight_compressor.encode_weight_and_scale_tensor(
+                        self.arch,
+                        sched_op.parent_op,
+                        weight_tensor,
+                        scale_tensor,
+                        sched_op.kernel,
+                        cost.block_config,
+                        cost.ofm_depth_slices,
+                    )
+
+                    # The chosen buffering might not fit at all; iterate until it does,
+                    # or until the minimum usable slice size is reached
+                    if (
+                        encoded_weights.max_range_bytes <= half_buffer_limit
+                        or prebuffer_depth == ArchitectureFeatures.OFMSplitDepth
+                    ):
+                        break
+
+                    prebuffer_depth = round_up(prebuffer_depth // 2, ArchitectureFeatures.OFMSplitDepth)
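+                    # Halving (rounded up to the OFM split granule) strictly decreases the depth,
+                    # so the loop is guaranteed to reach OFMSplitDepth and terminate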
+
+                # Calculate cycles required to run the last op for use as future slack
+                tail_cycles = self.estimate_op_performance(
+                    sched_op, cost.block_config, depth_slices[-1] - depth_slices[-2]
+                )
+                cost.slack_buffering_cycles = tail_cycles.op_cycles
+
+        # Determine whether the weights need to be double buffered
+        weight_buffer_size = min(len(encoded_weights.buffer), encoded_weights.max_range_bytes)
+
+        # Only buffer weights if there's still space left for the buffer
+        if weight_buffer_size <= buffer_limit_bytes:
+            assert weight_buffer_size % 16 == 0
+            # Determine whether to double buffer or single buffer
+            if (weight_buffer_size * 2 <= buffer_limit_bytes) and (weight_buffer_size < len(encoded_weights.buffer)):
+                weight_buffer_size = weight_buffer_size * 2
+                weight_tensor_purpose = TensorSubPurpose.DoubleBuffer
+            else:
+                weight_tensor_purpose = TensorSubPurpose.Standard
+
+            cost.buffered_weight_tensor = Tensor(
+                [1, 1, 1, weight_buffer_size], DataType.uint8, weight_tensor.name + "_buffer"
+            )
+            cost.buffered_weight_tensor.src_tensor = encoded_weights
+            cost.buffered_weight_tensor.mem_area = self.arch.fast_storage_mem_area
+            cost.buffered_weight_tensor.mem_type = MemType.Scratch_fast
+            cost.buffered_weight_tensor.purpose = TensorPurpose.Weights
+            cost.buffered_weight_tensor.sub_purpose = weight_tensor_purpose
+            if ref_cost.cascade == 0:
+                # Determine if the lifetime can be extended and pre-buffer weights under the previous operation
+                cost.buffered_weight_tensor.pre_buffer = weight_buffer_size < slack_memory
+
+            cost.slack_buffering_memory -= weight_buffer_size
+        else:
+            # Don't slice or buffer - use the whole depth from persistent storage
+            cost.ofm_depth_slices = ofm_full_depth_slices
+            encoded_weights = full_weights
+
+        cost.npu_weights_tensor = encoded_weights
+
+    def propose_minimal_schedule(self) -> Schedule:
+        """Proposes scheduling parameters where every operator is subdivided into the smallest stripe that satisfies the
+        next operators stride"""
+        min_schedule = Schedule(self.sg, "MIN")
+        cost_map = min_schedule.cost_map
+
+        # Keep track of the previous Op - which consumes the current Op's OFM
+        prev_op = None
+        for sched_op in reversed(self.sched_ops):
+            min_stripe_height = prev_op.kernel.stride.y if prev_op else 1
+            min_stripe = sched_op.ofm.shape.with_height(min_stripe_height)
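+            # e.g. a consumer with a vertical stride of 2 forces this Op to produce stripes
+            # that are at least 2 rows high (1 row when there is no consumer)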
+
+            cost = sched_op.create_scheduler_info(self.nng, min_stripe)
+            cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth)
+            cost_map[sched_op] = cost
+
+            prev_op = sched_op
+
+        return min_schedule
+
+    def propose_schedule_striping(self, final_stripe: Shape4D, label: str, ref_schedule: Schedule) -> Schedule:
+        """Proposes new striping for a schedule. The stripe is derived from the ifm requirements of the next Op down"""
+        ref_cost = ref_schedule.cost_map
+
+        striped_schedule = Schedule(self.sg, label)
+        stripe = final_stripe
+        for sched_op in reversed(self.sched_ops):
+            if sched_op not in ref_cost:
+                # sched_op is not part of the sub-schedule - skip
+                continue
+
+            # Create a cost entry with the new stripe
+            cost = sched_op.create_scheduler_info(self.nng, stripe)
+
+            # Copy the weight buffering from the reference schedule
+            cost.buffered_weight_tensor = ref_cost[sched_op].buffered_weight_tensor
+
+            # Estimate performance
+            cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth)
+            striped_schedule.cost_map[sched_op] = cost
+
+            # Calculate the preceding Op's stripe
+            stripe = sched_op.ifm.shape.with_height(stripe.height * sched_op.kernel.stride.y)
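+            # Illustrative: a 4-row OFM stripe on a stride-2 Op means its producer has to
+            # supply an 8-row stripe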
+
+        return striped_schedule
+
+    def estimate_schedule_memory_usage(self, schedule: Schedule, non_local_mem_usage: dict):
+        """Estimates the memory usage of a schedule"""
+        cost = schedule.cost_map
+        cascades = schedule.cascades
+        peak_mem_usage = 0
+        for sched_op in self.sched_ops:
+            if sched_op not in cost:
+                # sched_op is not part of the sub-schedule - skip
+                continue
+
+            if cost[sched_op].cascade:
+                # This Op is part of a cascade - use the cascade's memory usage
+                cascade_info = cascades[cost[sched_op].cascade]
+                # Non-local memory usage is already included in the cascade_info
+                peak_mem_usage = max(cascade_info.mem_usage, peak_mem_usage)
+            else:
+                # This Op is not part of a cascade - calculate the memory usage
+                op_weight_buffer = 0
+                if cost[sched_op].buffered_weight_tensor:
+                    op_weight_buffer = cost[sched_op].buffered_weight_tensor.storage_size()
+
+                op_mem_usage = (
+                    sched_op.ifm_size_in_bytes()
+                    + sched_op.ofm_size_in_bytes()
+                    + op_weight_buffer
+                    + non_local_mem_usage.get(sched_op, 0)
+                )
+                peak_mem_usage = max(op_mem_usage, peak_mem_usage)
+
+        return peak_mem_usage
+
+    def optimize_sub_schedule(
+        self, cascade_info: CascadeInfo, ref_schedule: Schedule, max_template: Schedule, memory_limit: int
+    ) -> Schedule:
+        """Extracts the Ops covered by the given cascade and creates a sub-schedule. The sub-schedule is optimized by
+        proposing weight buffering and then continously proposing new stripe sizes"""
+        ref_cost = ref_schedule.cost_map
+        # Extract the ops that are part of this sub-schedule
+        start = cascade_info.start
+        end = cascade_info.end
+        sub_schedule_ops = self.sched_ops[start : end + 1]
+        # Create a sub-schedule that contains only the costs for the Ops that are part of the sub-schedule
+        sub_schedule = Schedule(self.sg, f"SUB_{start}_{end}")
+        for sched_op in sub_schedule_ops:
+            sub_schedule.cost_map[sched_op] = ref_cost[sched_op]
+
+        sub_schedule.cascades[end] = cascade_info
+        # Use the memory snapshot from the reference schedule
+        sub_schedule.memory_snapshot = ref_schedule.memory_snapshot
+
+        # Calculate memory usage that is live during the sub-schedule but not part of it
+        time_for_cascade = ref_cost[sub_schedule_ops[0]].time_index
+        mem_usage_parallel_to_sub_schedule = ref_schedule.memory_snapshot[time_for_cascade] - cascade_info.mem_usage
+        # If the first Op's IFM has other consumers it has to live throughout the whole sub-schedule whether it's
+        # included in a cascade or not
+        persistent_initial_ifm = (
+            sub_schedule_ops[0].ifm_size_in_bytes() if len(sub_schedule_ops[0].ifm.connection.consumers) > 1 else 0
+        )
+        # Calculate non-local-mem-usage per Operator
+        non_local_mem_usage = {}
+        for idx, sched_op in enumerate(sub_schedule_ops):
+            non_local_mem_usage[sched_op] = mem_usage_parallel_to_sub_schedule
+            if idx != 0:
+                non_local_mem_usage[sched_op] += persistent_initial_ifm
+
+        cascade_builder = CascadeBuilder(sub_schedule_ops, self.arch.is_spilling_enabled(), non_local_mem_usage)
+
+        # Start by adding buffering
+        buffered_sub_schedule = self.propose_schedule_buffering(sub_schedule)
+        # Copy the cascades over from the unbuffered-schedule
+        buffered_sub_schedule.cascades = sub_schedule.cascades
+
+        # Generate the possible stripings for the final Op in the sub-schedule
+        final_ofm_shape = sub_schedule_ops[-1].ofm.shape
+        possible_stripes = [
+            final_ofm_shape.with_height(stripe_h) for stripe_h in range(1, final_ofm_shape.height // 2 + 1)
+        ]
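+        # e.g. a 16-row OFM gives candidate stripe heights 1..8; each iteration below keeps
+        # either the taller or the shorter half of the candidates, as in a binary search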
+
+        # Propose different striping - the possible stripes are proposed similarly to a binary search
+        best_schedule = buffered_sub_schedule
+        iteration = 0
+        while len(possible_stripes) > 1:
+            proposed_stripe = possible_stripes[len(possible_stripes) // 2]
+            proposed_schedule = self.propose_schedule_striping(
+                proposed_stripe, f"OPTIMIZED_{iteration}", buffered_sub_schedule
+            )
+
+            cascade_builder.build_cascades(proposed_schedule, max_template, memory_limit)
+
+            # Check if proposal fits
+            proposed_schedule_mem_usage = self.estimate_schedule_memory_usage(proposed_schedule, non_local_mem_usage)
+            if (proposed_schedule_mem_usage) <= memory_limit:
+                # Remove all possible stripes smaller than this
+                possible_stripes = possible_stripes[len(possible_stripes) // 2 :]
+                best_schedule = proposed_schedule
+                if not proposed_schedule.cascades:
+                    # No cascading required - early exit
+                    break
+            else:
+                # Proposal doesn't fit within the limit - remove all possible stripes larger than this
+                possible_stripes = possible_stripes[: len(possible_stripes) // 2]
+
+            iteration += 1
+
+        return best_schedule
+
+    def optimize_schedule(
+        self, schedule: Schedule, max_sched: Schedule, max_template: Schedule, options: SchedulerOptions,
+    ) -> Schedule:
+        """Extracts sub-schedules based on the cascades and optimizes them and applies them to the final schedule"""
+        sram_limit = options.optimization_sram_limit
+        if max_sched.fast_storage_peak_usage < sram_limit and not self.arch.is_spilling_enabled():
+            # Maximum performance schedule fits within the SRAM target
+            return max_sched
+
+        # Extract the cascades
+        cascades = [cascade for cascade in schedule.cascades.values()]
+        for cascade_info in cascades:
+            # Remove existing cascade from schedule
+            del schedule.cascades[cascade_info.end]
+            # Optimize the sub-schedule in this cascade
+            opt_sub_schedule = self.optimize_sub_schedule(cascade_info, schedule, max_template, sram_limit)
+            # Update the sub-schedule Op and cascade costs to the full schedule
+            schedule.cost_map.update(opt_sub_schedule.cost_map)
+            schedule.cascades.update(opt_sub_schedule.cascades)
+
+        # Update memory snapshot
+        self.sg.schedule = schedule
+        self.update_op_memory_snapshot(schedule)
+        # Propose schedule buffering to the optimized schedule
+        optimized_sched = self.propose_schedule_buffering(schedule)
+        # Copy the cascade's metadata from the unbuffered schedule
+        optimized_sched.cascades = schedule.cascades
+        return optimized_sched
+
+    def apply_schedule(self, sched: Schedule):
+        """Applies the given schedule as a final solution"""
+        for sched_op in self.sched_ops:
+            op_info = sched.cost_map[sched_op]
+            cascade_info = sched.cascades.get(op_info.cascade, None)
+            if cascade_info and sched_op in cascade_info.buffers:
+                buffer_tens = sched_op.ifm.connection.parent_tens
+                # Apply memory area and type
+                buffer_tens.mem_area = self.arch.fast_storage_mem_area
+                buffer_tens.mem_type = MemType.Scratch_fast
+                # Apply Rolling buffer
+                buffer_tens.set_format(TensorFormat.NHCWB16, self.arch)
+                buffer_tens.set_new_sub_purpose(TensorSubPurpose.RollingBufferY, cascade_info.buffers[sched_op].height)
+
+            sched_op.parent_ps.block_config = op_info.block_config.old_style_representation()
+
+            # Ensure that the src_tensor reference is set correctly
+            if op_info.buffered_weight_tensor:
+                op_info.buffered_weight_tensor.src_tensor = op_info.npu_weights_tensor
+
+    def use_fast_storage_for_feature_maps(self, schedule: Schedule, memory_limit: int):
+        if self.arch.fast_storage_mem_area == self.arch.feature_map_storage_mem_area:
+            return
+
+        # Force all OFMs to fast-storage
+        for sched_op in self.sched_ops:
+            cost = schedule.cost_map[sched_op]
+            if cost.cascade == 0:
+                if sched_op.get_dependants():
+                    ofm_tens = sched_op.ofm.connection.parent_tens
+                    if not any(cons is None for cons in ofm_tens.consumer_list):
+                        ofm_tens.mem_area = self.arch.fast_storage_mem_area
+                        ofm_tens.mem_type = MemType.Scratch_fast
+
+        # Collect live ranges from tensors
+        memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
+        lr_graph = live_range.LiveRangeGraph()
+        for mem_area, mem_type_set in memories_list:
+            live_range.extract_live_ranges_from_cascaded_passes(
+                self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum,
+            )
+
+        # Iterate over live ranges and evict tensors that don't fit
+        fast_storage_snapshot = lr_graph.get_temporal_memory_usage(self.arch.fast_storage_mem_area)
+        for lr in lr_graph.lrs:
+            if (
+                lr.mem_area == self.arch.fast_storage_mem_area
+                and max(fast_storage_snapshot[lr.start_time : lr.end_time + 1]) > memory_limit
+            ):
+                # Evict tensor to DRAM
+                for tens in lr.tensors:
+                    if tens.purpose == TensorPurpose.FeatureMap and tens.sub_purpose == TensorSubPurpose.Standard:
+                        # Can only evict unbuffered FeatureMaps
+                        tens.mem_area = self.arch.feature_map_storage_mem_area
+                        tens.mem_type = MemType.Scratch
+                        # Adjust the snapshot
+                        fast_storage_snapshot[lr.start_time : lr.end_time + 1] -= lr.size
+
+    def move_constant_data(self):
+        """Determine if  data, can be moved from permanent storage to another memory area. A move
+        will generate a DMA command in the high-level command stream"""
+        for sched_op in self.sched_ops:
+            parent_op = sched_op.parent_op
+            is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in parent_op.inputs)
+            max_ifm_shram_avail = (
+                (self.arch.available_shram_banks(is_lut_used) - self.arch.shram_reserved_output_banks)
+                * self.arch.shram_bank_size
+                // 2
+            )
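+            # i.e. half of the SHRAM space (in bytes) left for input data once the output
+            # banks have been reserved (available_shram_banks already accounts for LUT usage)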
+
+            for idx, tens in enumerate(parent_op.inputs):
+                if tens.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
+                    # Tensor is in permanent storage
+                    # Moving the data is only worthwhile when permanent storage differs from the feature map storage
+                    if (
+                        tens.mem_area in self.arch.permanent_storage_mem_area
+                        and self.arch.permanent_storage_mem_area != self.arch.feature_map_storage_mem_area
+                    ) or tens.purpose == TensorPurpose.LUT:
+                        if tens.purpose == TensorPurpose.LUT or (
+                            tens.purpose == TensorPurpose.FeatureMap
+                            and sched_op.op_type.is_binary_elementwise_op()
+                            and tens.shape != []
+                            and sched_op.ifm.shape != sched_op.ofm.shape
+                            and tens.storage_size() > max_ifm_shram_avail
+                        ):
+                            only_vector_product_consumers = all(
+                                oper and oper.type.npu_block_type == NpuBlockType.VectorProduct
+                                for oper in tens.consumers()
+                            )
+
+                            if (not only_vector_product_consumers) or tens.purpose == TensorPurpose.LUT:
+                                new_tens = tens.clone_into_fast_storage(self.arch)
+                                if tens.purpose == TensorPurpose.LUT:
+                                    new_tens.mem_area = MemArea.Shram
+
+                                new_tens.consumer_list.append(parent_op)
+                                parent_op.inputs[idx] = new_tens
+                                sched_op.parent_ps.inputs[idx] = new_tens
+
+    def print_schedule(self, schedule: Schedule):
+        print(f"Schedule: '{schedule.name}'")
+        for sched_op in self.sched_ops:
+            if sched_op not in schedule.cost_map:
+                # Op is not part of this (sub-)schedule - skip
+                continue
+
+            op_info = schedule.cost_map[sched_op]
+            print(f"\t{sched_op.index}: Operation {sched_op.name}  - OFM {sched_op.ofm.shape}")
+            print(f"\t\tType: {sched_op.op_type}")
+            print(f"\t\tKernel: {sched_op.kernel}")
+            print(f"{op_info}")
+            mem_usage = (
+                schedule.memory_snapshot[op_info.time_index]
+                if op_info.time_index < len(schedule.memory_snapshot)
+                else 0
+            )
+            print(f"\t\tSRAM Used: {mem_usage} bytes")
+
+        print(f"\tCascades:")
+        for i, cascade in enumerate(schedule.cascades.values()):
+            print(f"\t\t{i}: {cascade.start} -> {cascade.end}, size: {cascade.mem_usage}")
 
 
-def move_scales_to_fast_storage(nng, arch):
+def _update_tensor_allocation(nng: Graph, arch: ArchitectureFeatures, options):
+    """
+    Creates live ranges and runs the tensor allocator for the current schedule
+    (i.e. sg.schedule for all subgraphs).
+    """
+    root_sg = nng.get_root_subgraph()
+
+    alloc_list = []
+    if arch.is_spilling_enabled():
+        mem_alloc_scratch_fast = (arch.fast_storage_mem_area, set((MemType.Scratch_fast,)))
+        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch,)))
+        # Order is important
+        alloc_list.append(mem_alloc_scratch_fast)
+        alloc_list.append(mem_alloc_scratch)
+    else:
+        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))
+        alloc_list.append(mem_alloc_scratch)
+
+    for mem_area, mem_type_set in alloc_list:
+        tensor_allocation.allocate_tensors(
+            nng,
+            root_sg,
+            arch,
+            mem_area,
+            mem_type_set,
+            tensor_allocator=options.tensor_allocator,
+            verbose_allocation=options.verbose_allocation,
+            cpu_tensor_alignment=options.cpu_tensor_alignment,
+        )
+
+
+def schedule_passes(nng: Graph, arch: ArchitectureFeatures, options, scheduler_options: SchedulerOptions):
+    """Entry point for the Scheduler"""
+    # Schedulers for the NPU subgraphs; CPU subgraphs only get trivial cascaded passes
+    schedulers = dict()
+    # Initialize schedulers with the Max schedule. Only NPU subgraphs are scheduled
     for sg in nng.subgraphs:
-        # IFM streamed ops reads bias tensors several times, move these to fast storage
-        for cp in sg.cascaded_passes:
-            if cp.strategy == SchedulingStrategy.IfmStream:
-                # Calculate SRAM usage
-                new_size = 0
-                all_tens = []
-                for ps in cp.passes:
-                    pass_tens = np.array([ps.ifm_tensor, ps.ifm2_tensor, ps.ofm_tensor, ps.weight_tensor])
-                    pass_tens = np.append(pass_tens, ps.intermediates)
-                    for tens in pass_tens:
-                        if tens and tens.mem_area == MemArea.Sram and tens not in all_tens:
-                            all_tens.append(tens)
-                            new_size += tens.storage_size()
+        if sg.placement != PassPlacement.Npu:
+            # Create cascaded passes for CPU Ops
+            cascaded_passes = []
+            for idx, ps in enumerate(sg.passes):
+                cps = CascadedPass(
+                    ps.name, SchedulingStrategy.WeightStream, ps.inputs, [], ps.outputs, [ps], ps.placement, False,
+                )
 
-                cp.sram_used = new_size
+                cps.time = idx
+                ps.cascade = cps
+                cascaded_passes.append(cps)
 
-                for ps in cp.passes:
-                    if ps.scale_tensor:
-                        tens = ps.scale_tensor
+            sg.cascaded_passes = cascaded_passes
+        else:
+            # Npu subgraph - create schedule
+            scheduler = Scheduler(nng, sg, arch, scheduler_options)
+            schedulers[sg] = scheduler
 
-                        # Find op using scale tensor
-                        op = next((op for op in ps.ops if tens in op.inputs), None)
-                        assert op
+            scheduler.create_scheduler_representation(arch)
+            sg.sched_ops = scheduler.sched_ops
+            scheduler.move_constant_data()
 
-                        # Create fast storage tensor
-                        new_tens = tens.clone_into_fast_storage(arch)
-                        new_tens.consumer_list = tens.consumer_list.copy()
-                        new_tens.purpose = TensorPurpose.FSBias
-                        new_tens_size = new_tens.storage_size()
+            # Create the Max schedule template
+            max_schedule_template = scheduler.create_initial_schedule()
+            scheduler.max_schedule = max_schedule_template
 
-                        if (cp.sram_used + new_tens_size) <= arch.sram_size:
-                            # Create DMA cmd
-                            dma_cmd = Operation(Op.DMA, tens.ops[0].name + "_dma")
-                            dma_cmd.inputs = [tens]
-                            dma_cmd.set_output_tensor(new_tens)
-                            dma_cmd.attrs["source"] = tens.mem_area
-                            dma_cmd.attrs["destination"] = new_tens.mem_area
-                            dma_cmd.run_on_npu = True
+            # Create the optimised Max schedule
+            sg.schedule = max_schedule_template
+            scheduler.update_op_memory_snapshot(max_schedule_template)
+            opt_max_schedule = scheduler.propose_schedule_buffering(max_schedule_template)
+            sg.schedule = opt_max_schedule
+            scheduler.update_op_memory_snapshot(opt_max_schedule)
 
-                            tens.consumer_list.clear()
-                            tens.consumer_list.append(dma_cmd)
+            # Create Min schedule
+            min_schedule = scheduler.propose_minimal_schedule()
+            initial_sram_limit = scheduler_options.optimization_sram_limit
+            if scheduler_options.optimization_strategy == OptimizationStrategy.Size:
+                initial_sram_limit = scheduler.min_memory_req
 
-                            # Replace tensor and op
-                            idx = op.inputs.index(tens)
-                            op.inputs[idx] = new_tens
+            cascade_builder = CascadeBuilder(scheduler.sched_ops, arch.is_spilling_enabled())
+            cascade_builder.build_cascades(min_schedule, max_schedule_template, initial_sram_limit)
+            sg.schedule = min_schedule
+            scheduler.update_op_memory_snapshot(min_schedule)
 
-                            ps.ops.insert(0, dma_cmd)
-                            ps.scale_tensor = new_tens
-                            ps.intermediates.append(new_tens)
-                            ps.cascade.intermediates.append(new_tens)
+            if scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
+                # Create an optimized schedule
+                sg.schedule = scheduler.optimize_schedule(
+                    min_schedule, opt_max_schedule, max_schedule_template, scheduler_options
+                )
+                scheduler.update_op_memory_snapshot(sg.schedule)
 
-                            cp.sram_used += new_tens_size
+            scheduler.apply_schedule(sg.schedule)
+            scheduler.use_fast_storage_for_feature_maps(sg.schedule, scheduler_options.optimization_sram_limit)
 
+            if scheduler_options.verbose_schedule:
+                scheduler.print_schedule(sg.schedule)
 
-def schedule_passes(nng, arch, options: SchedulerOptions):
-
-    for sg in nng.subgraphs:
-        sg.base_sram_used = 0
-
-    for sg in nng.subgraphs:
-        # re-entering the same nodes from different contexts requires us to
-        # build a simplified directed acyclic (DAG) version of the graph to
-        # use for traversal, rather than using a visit dictionary. this avoids
-        # recursing infinitely due to loops.
-        sg.build_pass_dag_predecessors()
-
-        dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)
-
-        strat_set = dps.search()
-
-        dps.apply_result(strat_set, arch)
-
-        if options.verbose_schedule:
-            sg.print_cascaded_passes()
-
-
-def _calc_tens_to_cps(sg, tensor_rewrites):
-    # Determines for each tensor the list of affected cascaded passes, in terms of SRAM consumption.
-    # Returns dictionary tensor -> list of cascaded passes
-    # Note: if cascaded passes are A, B, C, D, and a tensor is output
-    # of A and input to D, then it also consumes SRAM in passes B and C.
-    if "tens_to_cps" in sg.scheduling_info:
-        return sg.scheduling_info["tens_to_cps"]
-    # Determine life-time of tensors
-    min_index = {}
-    max_index = {}
-    index = 0
-    cps_list = [cps for cps in sg.cascaded_passes if cps.placement == PassPlacement.Npu]
-    for cps in cps_list:
-        for tens in cps.inputs + cps.outputs:
-            if tens in tensor_rewrites:
-                min_index[tens] = min(index, min_index.get(tens, len(cps_list)))
-                max_index[tens] = index
-        index += 1
-    # Convert to affected cps-es
-    tens_to_cps = {}
-    for tens in min_index:
-        tens_to_cps[tens] = cps_list[min_index[tens] : max_index[tens] + 1]
-    sg.scheduling_info["tens_to_cps"] = tens_to_cps
-    return tens_to_cps
-
-
-def use_fast_storage_for_feature_maps(sg, sram_limit, arch):
-    # Attempts to use as much fast storage as possible for feature maps shared between cascaded passes.
-    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
-    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
-    # Sort tensors first on life-time (smallest first), then on size (biggest first)
-    tens_list = sorted([(len(tens_to_cps[tens]), -tens.storage_size(), tens.name, tens) for tens in tens_to_cps])
-    for _, _, _, tens in tens_list:
-        cps_list = tens_to_cps[tens]
-        if len(cps_list) < 1:
-            continue
-        sz = tens.storage_size()
-        fits_in_fast_storage = all([cps.sram_used + sz <= sram_limit for cps in cps_list])
-        if fits_in_fast_storage:
-            tens.mem_area = arch.fast_storage_mem_area
-            tens.mem_type = MemType.Scratch_fast
-            tens.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
-            assert tens in tensor_rewrites
-            # Also rewrite reshapes
-            for rewrite_op in tensor_rewrites[tens]:
-                tens2 = rewrite_op.outputs[0]
-                tens2.mem_area = arch.fast_storage_mem_area
-                tens2.mem_type = MemType.Scratch_fast
-                tens2.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
-            for cps in cps_list:
-                cps.sram_used += sz
-
-
-def undo_use_fast_storage(sg, arch):
-    # Undoes the effects of a previous call to use_fast_storage_for_feature_maps
-    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
-    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
-    mem_area = arch.tensor_storage_mem_area[TensorPurpose.FeatureMap]
-    for tens, cps_list in tens_to_cps.items():
-        if tens.mem_type == MemType.Scratch_fast:
-            sz = tens.storage_size()
-            tens.mem_area = mem_area
-            tens.mem_type = MemType.Scratch
-            # Also undo reshapes
-            for rewrite_op in tensor_rewrites[tens]:
-                tens2 = rewrite_op.outputs[0]
-                tens2.mem_area = mem_area
-                tens2.mem_type = MemType.Scratch
-            for cps in cps_list:
-                cps.sram_used -= sz
+    # Evaluate schedule
+    _update_tensor_allocation(nng, arch, options)
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index 5d849d9..fd67403 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -16,8 +16,10 @@
 # Description:
 # Defines the class Shape4D.
 from collections import namedtuple
+from enum import Enum
 
 from .numeric_util import full_shape
+from .numeric_util import round_up
 from .numeric_util import round_up_divide
 
 
@@ -42,6 +44,27 @@
         return cls(tmp[0], tmp[1], tmp[2], tmp[3])
 
     @classmethod
+    def min(cls, lhs, rhs):
+        return Shape4D(
+            min(lhs.batch, rhs.batch), min(lhs.height, rhs.height), min(lhs.width, rhs.width), min(lhs.depth, rhs.depth)
+        )
+
+    @classmethod
+    def max(cls, lhs, rhs):
+        return Shape4D(
+            max(lhs.batch, rhs.batch), max(lhs.height, rhs.height), max(lhs.width, rhs.width), max(lhs.depth, rhs.depth)
+        )
+
+    @classmethod
+    def round_up(cls, lhs, rhs):
+        return Shape4D(
+            round_up(lhs.batch, rhs.batch),
+            round_up(lhs.height, rhs.height),
+            round_up(lhs.width, rhs.width),
+            round_up(lhs.depth, rhs.depth),
+        )
+
+    @classmethod
     def from_hwc(cls, h, w, c):
         return cls(1, h, w, c)
 
@@ -60,6 +83,25 @@
     def with_depth(self, new_depth):
         return Shape4D(self.batch, self.height, self.width, new_depth)
 
+    def with_axis(self, axis, new_val):
+        shape_as_list = self.as_list()
+        shape_as_list[axis] = new_val
+        return Shape4D.from_list(shape_as_list)
+
+    @staticmethod
+    def _clip_len(pos, length, size):
+        if pos < 0:
+            length = length + pos
+            pos = 0
+        return min(pos + length, size) - pos
+
+    def clip(self, offset, sub_shape):
+        n = Shape4D._clip_len(offset.batch, sub_shape.batch, self.batch)
+        h = Shape4D._clip_len(offset.height, sub_shape.height, self.height)
+        w = Shape4D._clip_len(offset.width, sub_shape.width, self.width)
+        c = Shape4D._clip_len(offset.depth, sub_shape.depth, self.depth)
+        return Shape4D(n, h, w, c)
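+    # Example (illustrative): on a (1, 4, 4, 8) volume, clip(Shape4D(0, 3, 0, 0), Shape4D(1, 2, 4, 8))
+    # returns Shape4D(1, 1, 4, 8) - the 2-row sub-shape is cut down to the single remaining row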
+
     def add(self, n, h, w, c):
         return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c)
 
@@ -74,6 +116,9 @@
             self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
         )
 
+    def __truediv__(self, rhs):
+        return Shape4D(self.batch / rhs.batch, self.height / rhs.height, self.width / rhs.width, self.depth / rhs.depth)
+
     def __mod__(self, rhs):
         return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth)
 
@@ -102,3 +147,52 @@
 
     def get_hw_as_list(self):
         return list([self.height, self.width])
+
+
+class VolumeIterator:
+    """
+    4D Volume iterator. Use to traverse 4D tensor volumes in smaller shapes.
+    """
+
+    class Direction(Enum):
+        CWHN = 0
+
+    def __init__(
+        self,
+        shape: Shape4D,
+        sub_shape: Shape4D,
+        start: Shape4D = Shape4D(0, 0, 0, 0),
+        delta: Shape4D = None,
+        dir=Direction.CWHN,
+    ):
+        self.b = start.batch
+        self.y = start.height
+        self.x = start.width
+        self.z = start.depth
+        self.shape = shape
+        self.sub_shape = sub_shape
+        self.delta = sub_shape if delta is None else delta
+        assert self.delta.elements() > 0, "Iterator will not move"
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.b >= self.shape.batch:
+            raise StopIteration()
+
+        offset = Shape4D(self.b, self.y, self.x, self.z)
+
+        # CWHN
+        self.z += self.delta.depth
+        if self.z >= self.shape.depth:
+            self.z = 0
+            self.x += self.delta.width
+            if self.x >= self.shape.width:
+                self.x = 0
+                self.y += self.delta.height
+                if self.y >= self.shape.height:
+                    self.y = 0
+                    self.b += self.delta.batch
+
+        return offset, self.shape.clip(offset, self.sub_shape)
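+
+
+# Usage sketch (illustrative): traversing a 1x4x4x32 volume in 1x2x2x16 blocks yields eight
+# (offset, clipped_shape) pairs, advancing depth first, then width, height and batch:
+#
+#   for offset, block in VolumeIterator(Shape4D(1, 4, 4, 32), Shape4D(1, 2, 2, 16)):
+#       pass  # process each sub-volume here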
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index ea4aaf0..c9a97c0 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -26,7 +26,6 @@
 from .architecture_features import Block
 from .architecture_features import SharedBufferArea
 from .architecture_features import SHRAMElements
-from .errors import AllocationError
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
 from .operation import Kernel
 from .operation import NpuBlockType
@@ -262,11 +261,6 @@
 
 def find_suitable_block_configs(arch, alloc: SharedBufferAllocation) -> List[Tuple]:
     """Returns list of block configs that would fit with the given shared buffer allocation"""
-    if arch.override_block_config:
-        config = alloc.try_block(arch.override_block_config)
-        if config is None:
-            raise AllocationError(f"Block config override '{arch.override_block_config}' cannot be allocated")
-        return [config]
 
     # Constrain the search space if the OFM is smaller than the max block size
     # - Add other block search constraints here if required
diff --git a/ethosu/vela/stats_writer.py b/ethosu/vela/stats_writer.py
index fbc47f8..32e4fd5 100644
--- a/ethosu/vela/stats_writer.py
+++ b/ethosu/vela/stats_writer.py
@@ -45,7 +45,7 @@
         ]
 
         labels += (
-            ["accelerator_configuration", "system_config", "memory_mode", "core_clock", "sram_size"]
+            ["accelerator_configuration", "system_config", "memory_mode", "core_clock", "arena_cache_size"]
             + [area.identifier_name() + "_bandwidth" for area in mem_areas]
             + ["weights_storage_area", "feature_map_storage_area"]
         )
@@ -89,7 +89,7 @@
                     arch.system_config,
                     arch.memory_mode,
                     arch.core_clock,
-                    arch.sram_size / 1024,
+                    arch.arena_cache_size / 1024,
                 ]
                 + [arch.memory_bandwidths_per_second[mem_area] / 1000.0 / 1000 / 1000 for mem_area in mem_areas]
                 + [
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 15bd05e..7dbdcdd 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -345,6 +345,7 @@
         "dtype",
         "name",
         "is_variable",
+        "pre_buffer",
         "ops",
         "consumer_list",
         "values",
@@ -372,6 +373,7 @@
         "block_traversal",
         "equivalence_id",
         "resampling_mode",
+        "src_tensor",
         "needs_linear_format",
     )
     AllocationQuantum = 16
@@ -383,6 +385,7 @@
         self.dtype = dtype
         self.name = name
         self.is_variable = False
+        self.pre_buffer = False
         self.equivalence_id: UUID = uuid.uuid4()
 
         self.ops: List[Operation] = []
@@ -420,6 +423,9 @@
 
         self.needs_linear_format = True
 
+        # Reference to parent-tensor if this tensor is a clone
+        self.src_tensor = None
+
     @property
     def address(self) -> int:
         return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)
@@ -460,6 +466,7 @@
         res = self.clone(suffix="_fast_storage")
         res.mem_area = arch.fast_storage_mem_area
         res.mem_type = MemType.Scratch_fast
+        res.src_tensor = self
         return res
 
     def copy_compressed_weight_info(self, src_tens: "Tensor"):
@@ -536,31 +543,6 @@
         rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
         return rounded_size
 
-    def storage_size_for_sub_purpose(
-        self, arch, sub_purpose: TensorSubPurpose, param_a: Optional[int] = None, param_b: Optional[int] = None
-    ) -> int:
-        alt_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
-        elems = shape_num_elements(alt_shape)
-        if elems is None:
-            return 0
-        if sub_purpose == TensorSubPurpose.DoubleBuffer:
-            raw_size = (
-                elems
-                * self.element_size()
-                * self.compression_scale_for_worst_weight_stream
-                * arch.weight_estimation_scaling
-            )
-        else:
-            # Rolling buffers are used for intermediate data in ifm streaming
-            # These will all use the NHCWB16 format, and need to be aligned to 16 in the C-dimension
-            if alt_shape[-1] % 16 != 0:
-                nhcwb16_shape = alt_shape[0:-1] + [numeric_util.round_up(alt_shape[-1], 16)]
-                elems = shape_num_elements(nhcwb16_shape)
-
-            raw_size = elems * self.element_size() * self.storage_compression_scale
-        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
-        return rounded_size
-
     def storage_shape_for_sub_purpose(
         self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
     ) -> Shape:
@@ -724,19 +706,9 @@
         assert strides is not None
         return strides
 
-    def needs_dma(self) -> bool:
-        return len(self.ops) == 1 and self.ops[0].type == Op.DMA
-
-    def get_dma_src_tensor(self) -> "Optional[Tensor]":
-        # For weight tensors that need DMA: returns the source tensor in Flash, else None
-        # Note: for DMA ops, Pass.weight_tensor is referring to the SRAM weight tensor
-        return self.ops[0].inputs[0] if self.needs_dma() else None
-
     def find_npu_op(self) -> Optional[Operation]:
-        # Returns the NPU operator that uses this tensor, excluding DMA operators.
+        # Returns the NPU operator that uses this tensor
         for op in self.consumers():
-            if op.type == Op.DMA:
-                return op.outputs[0].find_npu_op()
             if op.run_on_npu:
                 return op
         return None
@@ -779,6 +751,7 @@
         self, orig_coord: Shape, op_shape4D: Optional[Shape4D] = None, is_top_box: bool = False
     ) -> Optional[int]:
         address_offset = 0
+        assert self.purpose != TensorPurpose.Weights
 
         if self.sub_purpose == TensorSubPurpose.Standard:
             shape = op_shape4D.as_list() if op_shape4D else self.shape
@@ -787,63 +760,29 @@
                     assert c > 0 and c <= shape[idx]
                 else:
                     assert c >= 0 and c < shape[idx]
-
-        if self.format == TensorFormat.WeightsCompressed:
-            storage_size = self.storage_size()
-            if len(self.weight_compressed_offsets) == 0:
-                return 0
-
-            if self.needs_dma() and self.sub_purpose == TensorSubPurpose.DoubleBuffer:
-                depth = orig_coord[-1]
-                brick_depth = self.brick_size[-1]
-                # Clamp position at final element index
-                if depth > self.shape[-1]:
-                    depth = self.shape[-1]
-
-                # Always round up to next boundary
-                index = numeric_util.round_up_divide(depth, brick_depth)
-                index = index % 2
-                assert self.compressed_values is not None
-
-                if len(self.compressed_values) <= 2:
-                    if is_top_box and index == 0:
-                        for cv in self.compressed_values:
-                            address_offset += len(cv)
-                    else:
-                        address_offset = index * len(self.compressed_values[0])
-                else:
-                    if is_top_box and index == 0:
-                        address_offset = self.storage_shape[-1]
-                    else:
-                        address_offset = index * (self.storage_shape[-1] // 2)
-            else:
-                index = self.compressed_stream_index_from_coord(orig_coord)
-                assert index < len(self.weight_compressed_offsets)
-                address_offset = self.weight_compressed_offsets[index]
+        coord = orig_coord
+        if op_shape4D and self.is_standard_fm:
+            storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
+            storage_size = self.storage_size_for_shape(storage_shape)
         else:
-            coord = orig_coord
-            if op_shape4D and self.is_standard_fm:
-                storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
-                storage_size = self.storage_size_for_shape(storage_shape)
-            else:
-                storage_shape = self.storage_shape
-                coord = coord[-len(storage_shape) :]
-                storage_size = self.storage_size()
+            storage_shape = self.storage_shape
+            coord = coord[-len(storage_shape) :]
+            storage_size = self.storage_size()
 
-            if is_top_box:
-                coord = [c - 1 for c in coord]
+        if is_top_box:
+            coord = [c - 1 for c in coord]
 
-            # handle wraparound for partial buffers. make sure to do this after subtracting top box:
-            coord = [c % storage_shape[idx] for idx, c in enumerate(coord)]
+        # handle wraparound for partial buffers. make sure to do this after subtracting top box:
+        coord = [c % storage_shape[idx] for idx, c in enumerate(coord)]
 
-            strides, augmented_coord = self.get_strides_and_coord(coord, op_shape4D)
-            if strides is None:
-                return None
+        strides, augmented_coord = self.get_strides_and_coord(coord, op_shape4D)
+        if strides is None:
+            return None
 
-            if is_top_box:
-                address_offset += 1 * strides[-1]  # one element
+        if is_top_box:
+            address_offset += 1 * strides[-1]  # one element
 
-            address_offset += np.dot(augmented_coord, strides)
+        address_offset += np.dot(augmented_coord, strides)
 
         assert address_offset >= 0
         assert address_offset <= storage_size
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index 724c7c0..d3e2a03 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -106,6 +106,8 @@
 
 
 def mark_sram_used_for_cascaded_passes(sg, lrs):
+    if len(sg.cascaded_passes) < 1:
+        return
     end_pos = max(ps.time for ps in sg.cascaded_passes) + 2
     mem_usage = np.zeros(end_pos, dtype=np.int64)
 
@@ -169,6 +171,40 @@
     return histogram
 
 
+def allocate(
+    sg,
+    arch,
+    mem_area,
+    mem_type_set,
+    tensor_allocator=TensorAllocator.Greedy,
+    lr_graph=None,
+    cpu_tensor_alignment=Tensor.AllocationQuantum,
+):
+    # Allocates addresses to tensors, returns False if tensors could not be fit within max_size
+    ignore_subgraph_input_output_tensors = False
+    lrs = live_range.extract_live_ranges_from_cascaded_passes(
+        sg,
+        mem_area,
+        mem_type_set,
+        ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
+        lr_graph=lr_graph,
+        cpu_tensor_alignment=cpu_tensor_alignment,
+    )
+    total_sz = 0
+    if lrs.ranges:
+        tens_alloc = tensor_allocator
+        if tens_alloc == TensorAllocator.Greedy:
+            total_sz = greedy_allocate_live_ranges(sg, arch, lrs, mem_area, cpu_tensor_alignment)
+            verify_allocation(lrs, cpu_tensor_alignment)
+        elif tens_alloc == TensorAllocator.LinearAlloc:
+            total_sz = linear_allocate_live_ranges(lrs, cpu_tensor_alignment)
+        elif tens_alloc == TensorAllocator.HillClimb:
+            total_sz = hillclimb_allocate_live_ranges(lrs, cpu_tensor_alignment)
+        else:
+            assert 0
+    return lrs, total_sz
+
+
 def allocate_tensors(
     nng,
     sg,
@@ -183,27 +219,17 @@
     dry_test=False,
 ):
     # Allocates addresses to tensors, returns False if tensors could not be fit within max_size
-    ignore_subgraph_input_output_tensors = False
-    lrs = live_range.extract_live_ranges_from_cascaded_passes(
+    lrs, total_sz = allocate(
         sg,
+        arch,
         mem_area,
         mem_type_set,
-        ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
+        tensor_allocator=tensor_allocator,
         lr_graph=lr_graph,
         cpu_tensor_alignment=cpu_tensor_alignment,
     )
 
     if lrs.ranges:
-        tens_alloc = tensor_allocator
-        if tens_alloc == TensorAllocator.Greedy:
-            total_sz = greedy_allocate_live_ranges(sg, arch, lrs, mem_area, cpu_tensor_alignment)
-            verify_allocation(lrs, cpu_tensor_alignment)
-        elif tens_alloc == TensorAllocator.LinearAlloc:
-            total_sz = linear_allocate_live_ranges(lrs, cpu_tensor_alignment)
-        elif tens_alloc == TensorAllocator.HillClimb:
-            total_sz = hillclimb_allocate_live_ranges(lrs, cpu_tensor_alignment)
-        else:
-            assert 0
         alloc_ok = max_size is None or total_sz <= max_size
         if dry_test or not alloc_ok:
             # Dry test or allocation failed; undo allocation
@@ -233,5 +259,4 @@
 
     if sg == nng.get_root_subgraph():
         nng.memory_used = sg.memory_used
-
     return True
diff --git a/ethosu/vela/test/extapi/test_extapi_generate_commands.py b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
index 3c9a43d..ee13430 100644
--- a/ethosu/vela/test/extapi/test_extapi_generate_commands.py
+++ b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
@@ -167,11 +167,13 @@
     check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, 15)
     check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1, 3)
     check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_DEPTH_M1, 15)
-    check_cmd0(cmds, cmd0.NPU_SET_IFM_IB_END, 14)
-    check_cmd0(cmds, cmd0.NPU_SET_AB_START, 14)
     check_cmd0(cmds, cmd0.NPU_SET_ACC_FORMAT, 0)
     check_cmd0(cmds, cmd0.NPU_SET_BLOCKDEP, 0)
     check_cmd0(cmds, cmd0.NPU_OP_CONV, 0)
+    ib_end = find_cmd0(cmds, cmd0.NPU_SET_IFM_IB_END)
+    ab_start = find_cmd0(cmds, cmd0.NPU_SET_AB_START)
+    assert ib_end > 0
+    assert ib_end <= ab_start
 
 
 def create_fully_connected_op() -> NpuConv2DOperation:
@@ -296,11 +298,13 @@
     check_cmd0(cmds, cmd0.NPU_SET_IFM2_PRECISION, 0)
     check_cmd0(cmds, cmd0.NPU_SET_IFM2_BROADCAST, 5)
     check_cmd0(cmds, cmd0.NPU_SET_IFM_IB_END, 16)
-    check_cmd0(cmds, cmd0.NPU_SET_AB_START, 16)
-    check_cmd0(cmds, cmd0.NPU_SET_IFM2_IB_START, 9)
     check_cmd0(cmds, cmd0.NPU_SET_ACC_FORMAT, 0)
     check_cmd0(cmds, cmd0.NPU_SET_BLOCKDEP, 0)
     check_cmd0(cmds, cmd0.NPU_OP_ELEMENTWISE, 0)
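+    # ab_start and ifm2_ib_start depend on the allocated SHRAM layout; check their ordering rather than exact values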
+    ab_start = find_cmd0(cmds, cmd0.NPU_SET_AB_START)
+    assert ab_start > 0
+    ifm2_ib_start = find_cmd0(cmds, cmd0.NPU_SET_IFM2_IB_START)
+    assert 0 < ifm2_ib_start < ab_start
     # Check that block width/height were generated that fit
     blk_height = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1)
     blk_width = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1)
@@ -413,11 +417,11 @@
     w, h = op.ofm.shape.width, op.ofm.shape.height
     op.ofm.tiles = NpuTileBox(width_0=w, height_0=h, height_1=h, addresses=[32 * 1024, 0, 0, 0])
     # 384K for spilling should fit
-    arch.sram_size = 384 * 1024
+    arch.arena_cache_size = 384 * 1024
     mem_limits = get_mem_limits_for_regions(arch)
     generate_command_stream([op], arch, verbose=False, mem_limits=mem_limits)
     # 32K for spilling does not fit, due to the OFM address
-    arch.sram_size = 32 * 1024
+    arch.arena_cache_size = 32 * 1024
     mem_limits = get_mem_limits_for_regions(arch)
     with pytest.raises(VelaError):
         generate_command_stream([op], arch, verbose=False, mem_limits=mem_limits)
diff --git a/ethosu/vela/test/test_architecture_allocator.py b/ethosu/vela/test/test_architecture_allocator.py
new file mode 100644
index 0000000..94768fc
--- /dev/null
+++ b/ethosu/vela/test/test_architecture_allocator.py
@@ -0,0 +1,123 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Unit tests for architecture_allocator.py
+import pytest
+
+from ethosu.vela.architecture_allocator import find_block_config
+from ethosu.vela.architecture_allocator import try_block_config
+from ethosu.vela.architecture_features import Accelerator
+from ethosu.vela.architecture_features import Block
+from ethosu.vela.architecture_features import create_default_arch
+from ethosu.vela.ethos_u55_regs.ethos_u55_regs import resampling_mode
+from ethosu.vela.operation import Kernel
+from ethosu.vela.operation import NpuBlockType
+from ethosu.vela.shape4d import Shape4D
+
+test_data = [
+    {
+        "block_type": NpuBlockType.ConvolutionDepthWise,
+        "kernel": Kernel(25, 5, 2, 2, 1, 1),
+        "ofm_shape": Shape4D(2, 11, 22),
+        "ifm_shape": Shape4D(27, 25, 22),
+    },
+    {
+        "block_type": NpuBlockType.Pooling,
+        "kernel": Kernel(2, 2),
+        "ofm_shape": Shape4D(53, 49, 22),
+        "ifm_shape": Shape4D(27, 25, 22),
+        "ifm_resampling": resampling_mode.NEAREST,
+    },
+    {
+        "block_type": NpuBlockType.ConvolutionMxN,
+        "accelerator": Accelerator.Ethos_U55_32,
+        "kernel": Kernel(2, 5),
+        "ofm_shape": Shape4D(48, 1, 17),
+        "ifm_shape": Shape4D(24, 5, 18),
+        "ifm_resampling": resampling_mode.TRANSPOSE,
+    },
+    {
+        "block_type": NpuBlockType.ElementWise,
+        "ofm_shape": Shape4D(27, 2, 22),
+        "ifm_shape": Shape4D(27, 2, 1),
+        "ifm2_shape": Shape4D(27, 25, 22),
+    },
+    {
+        "block_type": NpuBlockType.ElementWise,
+        "accelerator": Accelerator.Ethos_U55_32,
+        "ofm_shape": Shape4D(48, 37, 17),
+        "ifm_shape": Shape4D(48, 37, 17),
+        "uses_scalar": True,
+        "lut_banks": 2,
+    },
+    {
+        "block_type": NpuBlockType.ElementWise,
+        "ofm_shape": Shape4D(27, 2, 22),
+        "ifm_shape": Shape4D(27, 2, 22),
+        "ifm_bits": 16,
+    },
+]
+
+
+@pytest.mark.parametrize("test_data", test_data)
+def test_allocate(test_data):
+    """Tests that find_block_config and try_block_config produce consistent SHRAM layouts"""
+    accelerator = test_data.get("accelerator", Accelerator.Ethos_U55_128)
+    arch = create_default_arch(accelerator)
+    kernel = test_data.get("kernel", Kernel(1, 1))
+    block_type = test_data["block_type"]
+    ofm_shape = test_data["ofm_shape"]
+    ifm_shape = test_data["ifm_shape"]
+    ifm2_shape = test_data.get("ifm2_shape")
+    uses_scalar = test_data.get("uses_scalar", False)
+    ifm_bits = test_data.get("ifm_bits", 8)
+    ifm_resampling = test_data.get("ifm_resampling", resampling_mode.NONE)
+    scaled = test_data.get("scaled", True)
+    lut_banks = test_data.get("lut_banks", 0)
+    config = find_block_config(
+        arch,
+        block_type,
+        ofm_shape,
+        ifm_shape,
+        ifm2_shape,
+        uses_scalar=uses_scalar,
+        ifm_bits=ifm_bits,
+        kernel=kernel,
+        lut_banks=lut_banks,
+        scaled=scaled,
+        ifm_resampling=ifm_resampling,
+    )
+    assert config is not None
+    config2 = try_block_config(
+        Block.from_shape(config.ofm_block.as_list()),
+        arch,
+        block_type,
+        ifm_shape,
+        ifm2_shape,
+        is_partkernel=config.is_partkernel,
+        uses_scalar=uses_scalar,
+        ifm_bits=ifm_bits,
+        kernel=kernel,
+        lut_banks=lut_banks,
+        scaled=scaled,
+        ifm_resampling=ifm_resampling,
+    )
+    assert config2 is not None
+    assert config.layout.ib_end == config2.layout.ib_end
+    assert config.layout.ab_start == config2.layout.ab_start
+    assert config.layout.ib_start2 == config2.layout.ib_start2
+    assert config.acc_type == config2.acc_type
diff --git a/ethosu/vela/test/test_lut.py b/ethosu/vela/test/test_lut.py
index 44ee0af..4ddc8b9 100644
--- a/ethosu/vela/test/test_lut.py
+++ b/ethosu/vela/test/test_lut.py
@@ -19,7 +19,6 @@
 
 import numpy as np
 
-from ethosu.vela import insert_dma
 from ethosu.vela import lut
 from ethosu.vela import mark_tensors
 from ethosu.vela import pass_packing
@@ -27,37 +26,41 @@
 from ethosu.vela.high_level_command_stream import DMA
 from ethosu.vela.nn_graph import Graph
 from ethosu.vela.operation import Op
+from ethosu.vela.rewrite_graph import rewrite_graph_pre_order
 from ethosu.vela.rewrite_graph import verify_graph_health
 from ethosu.vela.tensor import create_const_tensor
 from ethosu.vela.tensor import TensorPurpose
 from ethosu.vela.test import testutil
 
 
-def set_256_lut(op, key):
+def set_256_lut(op, key, arch):
     random.seed(key)
     values = random.choices(range(256), k=256)
     lut_tensor = create_const_tensor(
         op.name + "_lut", [1, 1, 1, 256], DataType.int8, values, np.uint8, TensorPurpose.LUT
     )
-    op.set_activation_lut(lut_tensor)
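+    # clone_into_fast_storage keeps the original constant as the clone's src_tensor (used as the DMA source)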
+    scratch_lut_tensor = lut_tensor.clone_into_fast_storage(arch)
+    op.set_activation_lut(scratch_lut_tensor)
 
 
-def set_1K_lut(op, key):
+def set_1K_lut(op, key, arch):
     random.seed(key)
     values = random.choices(range(256), k=256)
     lut_tensor = create_const_tensor(
         op.name + "_lut", [1, 1, 1, 256], DataType.int32, values, np.uint32, TensorPurpose.LUT
     )
-    op.set_activation_lut(lut_tensor)
+    scratch_lut_tensor = lut_tensor.clone_into_fast_storage(arch)
+    op.set_activation_lut(scratch_lut_tensor)
 
 
-def set_2K_lut(op, key):
+def set_2K_lut(op, key, arch):
     random.seed(key)
     values = random.choices(range(512), k=512)
     lut_tensor = create_const_tensor(
         op.name + "_lut", [1, 1, 1, 512], DataType.int32, values, np.uint32, TensorPurpose.LUT
     )
-    op.set_activation_lut(lut_tensor)
+    scratch_lut_tensor = lut_tensor.clone_into_fast_storage(arch)
+    op.set_activation_lut(scratch_lut_tensor)
 
 
 def process(arch, op_list):
@@ -68,16 +71,16 @@
     assert verify_graph_health(nng)
     nng = mark_tensors.mark_tensor_purpose(nng, arch, False)
     assert verify_graph_health(nng)
-    nng = insert_dma.insert_dma_commands(nng, arch, False)
-    assert verify_graph_health(nng)
+    rewrite_graph_pre_order(nng, sg, arch, [], [])
     pass_packing.pack_into_passes(nng, arch, False)
     assert verify_graph_health(nng)
     # Create a DMA instruction for every op
     cmd_list = []
     for ps in sg.passes:
-        for intermediate in ps.intermediates:
-            if intermediate.needs_dma():
-                cmd_list.append(DMA(ps, intermediate.get_dma_src_tensor(), intermediate, None))
+        for input_tens in ps.inputs:
+            if input_tens.src_tensor:
+                cmd_list.append(DMA(ps, input_tens.src_tensor, input_tens, None))
+
     sg.high_level_command_stream = cmd_list
     return sg
 
@@ -96,28 +99,28 @@
     shape = [1, 1, 1, 1]
     # u8 LUT op, should lead to DMA
     op0 = testutil.create_elemwise_op(Op.Add, "op0", shape, shape, shape)
-    set_256_lut(op0, "lut0")
+    set_256_lut(op0, "lut0", arch)
     # u8 LUT op, should lead to DMA
     op1 = testutil.create_elemwise_op(Op.Add, "op1", shape, shape, shape)
-    set_256_lut(op1, "lut1")
+    set_256_lut(op1, "lut1", arch)
     # u8 LUT op with different LUT, should lead to DMA
     op2 = testutil.create_elemwise_op(Op.Add, "op2", shape, shape, shape)
-    set_256_lut(op2, "lut2")
+    set_256_lut(op2, "lut2", arch)
     # u8 LUT op with same LUT as in op1, should not lead to DMA
     op3 = testutil.create_elemwise_op(Op.Add, "op3", shape, shape, shape)
-    set_256_lut(op3, "lut1")
+    set_256_lut(op3, "lut1", arch)
     # u8 LUT op with same LUT as in op2, should not lead to DMA
     op4 = testutil.create_elemwise_op(Op.Add, "op4", shape, shape, shape)
-    set_256_lut(op4, "lut2")
+    set_256_lut(op4, "lut2", arch)
     # 2K LUT op, should lead to DMA, and will overwrite all previous LUTs in SHRAM
     op5_2K = testutil.create_elemwise_op(Op.Add, "op5", shape, shape, shape)
-    set_2K_lut(op5_2K, "lut5")
+    set_2K_lut(op5_2K, "lut5", arch)
     # Another 2K LUT op, should lead to DMA, and will overwrite the previous LUT in SHRAM
     op6_2K = testutil.create_elemwise_op(Op.Add, "op6", shape, shape, shape)
-    set_2K_lut(op6_2K, "lut6")
+    set_2K_lut(op6_2K, "lut6", arch)
     # u8 LUT op with same LUT as in op1, should lead to DMA
     op7 = testutil.create_elemwise_op(Op.Add, "op7", shape, shape, shape)
-    set_256_lut(op7, "lut1")
+    set_256_lut(op7, "lut1", arch)
 
     op_list = [op0, op1, op2, op3, op4, op5_2K, op6_2K, op7]
     sg = process(arch, op_list)
@@ -132,7 +135,7 @@
     orig_cmd_list = filter_lut_cmds(orig_cmd_list)
 
     for (cmd, op) in zip(cmd_list, expected_dma_ops):
-        assert cmd.in_tensor == op.activation_lut
+        assert cmd.in_tensor == op.activation_lut.src_tensor
     # Check that lut0, lut1 and lut2 in op0, op1, op2 are stored on different addresses
     assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[1].out_tensor.address
     assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[2].out_tensor.address
@@ -151,28 +154,28 @@
     shape = [1, 1, 1, 1]
     # u8 LUT op, should lead to DMA
     op0 = testutil.create_elemwise_op(Op.Add, "op0", shape, shape, shape)
-    set_256_lut(op0, "lut0")
+    set_256_lut(op0, "lut0", arch)
     # u8 LUT op, should lead to DMA
     op1 = testutil.create_elemwise_op(Op.Add, "op1", shape, shape, shape)
-    set_256_lut(op1, "lut1")
+    set_256_lut(op1, "lut1", arch)
     # 1K LUT op with different LUT, should lead to DMA
     op2_1K = testutil.create_elemwise_op(Op.Add, "op2", shape, shape, shape)
-    set_1K_lut(op2_1K, "lut2")
+    set_1K_lut(op2_1K, "lut2", arch)
     # u8 LUT op with same LUT as in op1, should not lead to DMA
     op3 = testutil.create_elemwise_op(Op.Add, "op3", shape, shape, shape)
-    set_256_lut(op3, "lut1")
+    set_256_lut(op3, "lut1", arch)
     # 1K LUT op with same LUT as in op2, should not lead to DMA
     op4_1K = testutil.create_elemwise_op(Op.Add, "op4", shape, shape, shape)
-    set_1K_lut(op4_1K, "lut2")
+    set_1K_lut(op4_1K, "lut2", arch)
     # 1K LUT op, should lead to DMA, and will overwrite lut2
     op5_2K = testutil.create_elemwise_op(Op.Add, "op5", shape, shape, shape)
-    set_1K_lut(op5_2K, "lut5")
+    set_1K_lut(op5_2K, "lut5", arch)
     # u8 LUT op, lut0 should still be present, should not lead to DMA
     op6 = testutil.create_elemwise_op(Op.Add, "op6", shape, shape, shape)
-    set_256_lut(op6, "lut0")
+    set_256_lut(op6, "lut0", arch)
     # 1K LUT op with same LUT as in op2, should lead to DMA
     op7 = testutil.create_elemwise_op(Op.Add, "op7", shape, shape, shape)
-    set_1K_lut(op7, "lut2")
+    set_1K_lut(op7, "lut2", arch)
 
     op_list = [op0, op1, op2_1K, op3, op4_1K, op5_2K, op6, op7]
     sg = process(arch, op_list)
@@ -187,7 +190,7 @@
     # Check that only the needed DMA commands are left
     expected_dma_ops = [op0, op1, op2_1K, op5_2K, op7]
     for (cmd, op) in zip(cmd_list, expected_dma_ops):
-        assert cmd.in_tensor == op.activation_lut
+        assert cmd.in_tensor == op.activation_lut.src_tensor
     # Check that lut0, lut1 and lut2 in op0, op1, op2 are stored on different addresses
     assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[1].out_tensor.address
     assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[2].out_tensor.address
diff --git a/ethosu/vela/test/test_new_performance.py b/ethosu/vela/test/test_new_performance.py
new file mode 100644
index 0000000..a35905b
--- /dev/null
+++ b/ethosu/vela/test/test_new_performance.py
@@ -0,0 +1,78 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Contains unit tests for new performance estimation code
+from ethosu.vela import architecture_allocator
+from ethosu.vela import architecture_features
+from ethosu.vela import npu_performance
+from ethosu.vela import operation
+from ethosu.vela.architecture_features import resampling_mode
+from ethosu.vela.shape4d import Shape4D
+from ethosu.vela.shape4d import VolumeIterator
+from ethosu.vela.tensor import MemArea
+
+
+def test_new_performance():
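+    """Exercises find_block_config and measure_performance_cost over several OFM sub-volume splits"""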
+    arch = architecture_features.create_default_arch(architecture_features.Accelerator.Ethos_U55_128)
+
+    query = npu_performance.PerformanceQuery(architecture_features.NpuBlockType.ConvolutionMxN)
+    query.ifm_shape = Shape4D(1, 16, 16, 16)
+    query.ifm2_shape = Shape4D()
+    query.ifm_memory_area = MemArea.Sram
+    query.ifm_bits = 8
+    query.ofm_shape = Shape4D(1, 16, 16, 1)
+    query.ofm_memory_area = MemArea.Sram
+    query.ofm_bits = 8
+    query.const_shape = Shape4D(1, 1, 1, query.ofm_shape.depth)
+    query.const_memory_area = MemArea.OffChipFlash
+    query.kernel = operation.Kernel(1, 1, 1, 1, 1, 1, valid_padding=False)
+    query.config = architecture_allocator.find_block_config(
+        arch,
+        architecture_features.NpuBlockType.ConvolutionMxN,
+        Shape4D(1, 16, 16, 1),
+        query.ifm_shape,
+        None,
+        False,
+        8,
+        query.kernel,
+        0,
+        False,
+        resampling_mode.NONE,
+    )
+
+    print("For block Config = {}".format(query.config))
+
+    # Run pytest with -s to display the printed output
+    for sub_shape in [Shape4D(1, 4, 8, 16), Shape4D(1, 8, 8, 16), Shape4D(1, 8, 16, 16), query.ofm_shape]:
+        print("\n-- Subshape = {}".format(sub_shape))
+        iterator = VolumeIterator(query.ofm_shape, sub_shape)
+        a = npu_performance.ElementAccess()
+        c = npu_performance.CycleCost()
+        for pos, shape in iterator:
+            print("\tpos = {} shape = {}".format(pos, shape))
+            ta, tc = npu_performance.measure_performance_cost(
+                arch, operation.Op.Conv2D, operation.Op.Relu, query, pos, shape
+            )
+            a += ta
+            c += tc
+            print("\t\taccess: {}".format(ta))
+            print("\t\tcycles: {}".format(tc))
+        print("\tAccess: {}".format(a))
+        print("\tCycles: {}".format(c))
+        assert c.op_macs == 4096
+
+    assert True  # Any successful result is okay
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index aa74ecf..f552b21 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -18,7 +18,6 @@
 #
 # Provides command line interface, options parsing, and network loading. Before calling the compiler driver.
 import argparse
-import ast
 import os
 import sys
 import time
@@ -38,7 +37,6 @@
 from .errors import VelaError
 from .nn_graph import PassPlacement
 from .nn_graph import TensorAllocator
-from .scheduler import ParetoMetric
 from .supported_operators import SupportedOperators
 from .tensor import MemArea
 from .tensor import Tensor
@@ -71,9 +69,6 @@
 
     compiler_driver.compiler_driver(nng, arch, compiler_options, scheduler_options)
 
-    passes_csv_file = "{0}_pass-breakdown_{1}.csv".format(output_basename, arch.system_config)
-    stats_writer.write_pass_metrics_csv(nng, passes_csv_file)
-
     summary_csv_file = "{0}_summary_{1}.csv".format(output_basename, arch.system_config)
     stats_writer.write_summary_metrics_csv(nng, summary_csv_file, arch)
 
@@ -276,11 +271,6 @@
         parser.add_argument("--verbose-tensor-purpose", action="store_true", help="Verbose tensor purpose")
         parser.add_argument("--verbose-tensor-format", action="store_true", help="Verbose tensor format")
         parser.add_argument("--verbose-schedule", action="store_true", help="Verbose schedule")
-        parser.add_argument(
-            "--verbose-pareto-frontier-schedules",
-            action="store_true",
-            help="Show all schedules along the pareto frontier of optimisation criteria",
-        )
         parser.add_argument("--verbose-allocation", action="store_true", help="Verbose tensor allocation")
         parser.add_argument(
             "--verbose-high-level-command-stream", action="store_true", help="Verbose high level command stream"
@@ -293,23 +283,6 @@
         parser.add_argument(
             "--show-cpu-operations", action="store_true", help="Show the operations that fall back to the CPU"
         )
-        parser.add_argument(
-            "--cache-bias-scale-tensor",
-            type=ast.literal_eval,
-            default=True,
-            choices=[True, False],
-            help="Controls the caching of the bias & scale tensors in SRAM (default: %(default)s)",
-        )
-        parser.add_argument(
-            "--cascading",
-            type=ast.literal_eval,
-            default=True,
-            choices=[True, False],
-            help="Controls the packing of multiple passes into a cascade (default: %(default)s)",
-        )
-        parser.add_argument(
-            "--force-block-config", type=str, default="", help="Force a specific block configuration WxHxC"
-        )
         parser.add_argument("--timing", action="store_true", help="Time the compiler doing operations")
         parser.add_argument(
             "--accelerator-config",
@@ -343,32 +316,6 @@
             help="Shows a summary of all the subgraphs and their inputs and outputs",
         )
         parser.add_argument(
-            "--ifm-streaming",
-            type=ast.literal_eval,
-            default=True,
-            choices=[True, False],
-            help="Controls scheduler IFM streaming search (default: %(default)s)",
-        )
-        parser.add_argument(
-            "--block-config-limit",
-            type=int,
-            default=16,
-            help="Limit block config search space, use zero for unlimited (default: %(default)s)",
-        )
-        parser.add_argument(
-            "--pareto-metric",
-            default=ParetoMetric.BwCycMem,
-            type=lambda s: ParetoMetric[s],
-            choices=list(ParetoMetric),
-            help="Controls the calculation of the pareto metric (default: %(default)s)",
-        )
-        parser.add_argument(
-            "--recursion-limit",
-            type=int,
-            default=10000,
-            help="Set the recursion depth limit, may result in RecursionError if too low (default: %(default)s)",
-        )
-        parser.add_argument(
             "--max-block-dependency",
             type=int,
             default=architecture_features.ArchitectureFeatures.MAX_BLOCKDEP,
@@ -379,17 +326,23 @@
             ),
         )
         parser.add_argument(
-            "--nhcwb16-between-cascaded-passes",
-            type=ast.literal_eval,
-            default=True,
-            choices=[True, False],
-            help="Control if NHCWB16 or NHWC should be used in between cascaded passes (default: %(default)s)",
+            "--optimise",
+            type=lambda s: scheduler.OptimizationStrategy[s],
+            default=scheduler.OptimizationStrategy.Performance,
+            choices=list(scheduler.OptimizationStrategy),
+            help=(
+                "Set the optimisation strategy. The Size strategy results in minimal SRAM usage (does not use"
+                " arena-cache-size). The Performance strategy results in maximal performance (uses the arena-cache-size"
+                " if specified) (default: %(default)s)"
+            ),
         )
         parser.add_argument(
-            "--weight-estimation-scaling",
-            type=float,
-            default=1.0,
-            help=("Performs an additional scaling of weight compression scale estimate (default: %(default)s)"),
+            "--arena-cache-size",
+            type=int,
+            help=(
+                "Set the size of the arena cache memory area, in bytes. If specified, this option overrides the memory"
+                " mode attribute with the same name in a Vela configuration file"
+            ),
         )
         parser.add_argument(
             "--cpu-tensor-alignment",
@@ -416,13 +369,6 @@
                 if not os.access(filename, os.R_OK):
                     raise InputFileError(filename, "File not found or is not readable")
 
-        sys.setrecursionlimit(args.recursion_limit)
-
-        if args.force_block_config:
-            force_block_config = architecture_features.Block.from_string(args.force_block_config)
-        else:
-            force_block_config = None
-
         if args.cpu_tensor_alignment < 16 or args.cpu_tensor_alignment & (args.cpu_tensor_alignment - 1) != 0:
             parser.error(
                 "Invalid argument to --cpu-tensor-alignment = {} (must be greater than or equal to 16 and a power of 2)"
@@ -445,11 +391,9 @@
             system_config=args.system_config,
             memory_mode=args.memory_mode,
             accelerator_config=args.accelerator_config,
-            override_block_config=force_block_config,
-            block_config_limit=args.block_config_limit,
             max_blockdep=args.max_block_dependency,
-            weight_estimation_scaling=args.weight_estimation_scaling,
             verbose_config=args.verbose_config,
+            arena_cache_size=args.arena_cache_size,
         )
 
         compiler_options = compiler_driver.CompilerOptions(
@@ -471,13 +415,9 @@
         )
 
         scheduler_options = scheduler.SchedulerOptions(
-            use_cascading=args.cascading,
+            optimization_strategy=args.optimise,
+            sram_target=arch.arena_cache_size,
             verbose_schedule=args.verbose_schedule,
-            verbose_pareto_frontier_schedules=args.verbose_pareto_frontier_schedules,
-            use_ifm_streaming=args.ifm_streaming,
-            pareto_metric=args.pareto_metric,
-            use_nhcwb16_between_cascaded_passes=args.nhcwb16_between_cascaded_passes,
-            cache_bias_scale_tensor=args.cache_bias_scale_tensor,
         )
 
         model_reader_options = model_reader.ModelReaderOptions()
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 9a1d5a1..652d016 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -16,6 +16,7 @@
 # Description:
 # Compresses and pads the weigths. It also calculates the scales and packs with the biases.
 from collections import namedtuple
+from collections import OrderedDict
 from typing import Tuple
 
 import numpy as np
@@ -25,27 +26,85 @@
 from .architecture_features import ArchitectureFeatures
 from .data_type import DataType
 from .errors import UnsupportedFeatureError
-from .nn_graph import SchedulingStrategy
 from .numeric_util import round_up
-from .numeric_util import round_up_divide
 from .operation import NpuBlockType
 from .operation import Op
 from .scaling import quantise_scale
 from .scaling import reduced_quantise_scale
-from .tensor import create_equivalence_id
-from .tensor import TensorBlockTraversal
+from .tensor import Tensor
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
-from .tensor import TensorSubPurpose
 from ethosu import mlw_codec
 
 
 # Contains meta info for a weight compression. If two tensors have identical weight compression config,
 # then they also will have identical compressed weights.
 WeightCompressionConfig = namedtuple(
-    "WeightCompressionConfig", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "value_id"]
+    "WeightCompressionConfig",
+    ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "weight_value_id", "scale_value_id"],
 )
 
+WeightKey = namedtuple("WeightKey", ["core", "depth"])
+
+
+class WeightRange:
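+    """Byte range of one core's encoded scale/bias and weight data within an NpuWeightTensor buffer"""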
+    def __init__(self):
+        self.offset = 0
+        self.scale_bytes = 0
+        self.weight_offset = 0
+        self.weight_bytes = 0
+        self.index = 0
+
+    @property
+    def total_bytes(self):
+        return self.scale_bytes + self.weight_bytes
+
+
+class NpuWeightTensor(Tensor):
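+    """Tensor holding the NPU-encoded weight/scale byte stream and its per-(core, depth) encoded ranges"""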
+    def __init__(self, name):
+        Tensor.__init__(self, None, None, name + "_npu_encoded_weights")
+        self.buffer = []
+        self.max_range_bytes = 0
+        self.encoded_ranges = OrderedDict()
+        self.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
+        self.dtype = DataType.uint8
+
+
+class CompressedWeightCache:
+    """Global tensor weight compression cache"""
+
+    cache = {}
+
+    @staticmethod
+    def get_tensor_with_same_compression(wcc):
+        return CompressedWeightCache.cache.get(wcc)
+
+    @staticmethod
+    def add(tens):
+        # Adds the compressed weights from the tensor to the cache
+        wcc = tens.weight_compression_config
+        CompressedWeightCache.cache[wcc] = tens
+
+    @staticmethod
+    def has_tensor_with_same_compression(wcc):
+        return wcc in CompressedWeightCache.cache
+
+    @staticmethod
+    def get_unencoded_size_with_same_compression(wcc):
+        cache_obj = CompressedWeightCache.cache.get(wcc)
+        return cache_obj[1] if cache_obj else None
+
+
+def create_weight_compression_config(
+    weight_tens, scale_tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation
+):
+    # Note: for an OFM block, only its depth is used in weight compression.
+    # A block depth greater than the OFM depth gives the same result as a block depth equal to the OFM depth
+    block_depth = min(ofm_block_depth, weight_tens.quant_values.shape[-1])
+    return WeightCompressionConfig(
+        npu_block_type, block_depth, ofm_depth_step, dilation, weight_tens.value_id, scale_tens.value_id
+    )
+
 
 def encode_weights(
     accelerator: Accelerator,
@@ -140,185 +199,13 @@
     return data
 
 
-def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
-    # Note: for an ofm block only its depth is used in weight compression.
-    # And block depth > ofm depth gives same result as block depth == ofm depth
-    block_depth = min(ofm_block_depth, tens.quant_values.shape[-1])
-    return WeightCompressionConfig(npu_block_type, block_depth, ofm_depth_step, dilation, tens.value_id)
-
-
-def set_storage_shape(tens):
-    # Sets the storage shape depending on the tensor's sub purpose
-    if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(tens.compressed_values) > 2:
-        offset = 2 * np.amax([len(x) for x in tens.compressed_values])
-        assert offset % 16 == 0
-    else:
-        offset = tens.weight_compressed_offsets[-1]
-    tens.storage_shape = [1, 1, 1, offset]
-
-
-class CompressedWeightCache:
-    # Contains weight compressions for all weight tensors in a graph
-    def __init__(self):
-        self.cache = {}  # maps from WeightCompressionConfig to a tensor clone containing compressed weights
-
-    def has_tensor_with_same_compression(self, wcc):
-        return self.cache.get(wcc) is not None
-
-    def get_tensor_with_same_compression(self, wcc):
-        cache_obj = self.cache.get(wcc)
-        return cache_obj[0] if cache_obj else None
-
-    def get_unencoded_size_with_same_compression(self, wcc):
-        cache_obj = self.cache.get(wcc)
-        return cache_obj[1] if cache_obj else None
-
-    def add(self, tens, unencoded_size):
-        # Adds the compressed weights from the tensor to the cache
-        wcc = tens.weight_compression_config
-        # Clone the tensor to make sure that nothing related to the weight compression is modified
-        tens_clone = tens.clone("_weights{}_{}".format(wcc.ofm_block_depth, wcc.ofm_depth_step))
-        self.cache[wcc] = (tens_clone, unencoded_size)
-
-
 def core_deinterleave(hwio, core, ncores):
     # Put weights back into OHWI
     ohwi = np.transpose(hwio, (3, 0, 1, 2))
     return ohwi[core : ohwi.shape[0] : ncores]
 
 
-# Compress the weights
-def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
-    assert tens.purpose == TensorPurpose.Weights
-
-    # Check the weight cache
-    if nng.weight_cache is None:
-        nng.weight_cache = CompressedWeightCache()
-    wcc = create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation)
-    tens.weight_compression_config = wcc
-    # Reassign equivalence id such that tensors with same weight compression get identical equivalence ids,
-    # but tensors with the same values but different compression get different equivalence ids
-    tens.equivalence_id = create_equivalence_id(wcc)
-    tens_cached = nng.weight_cache.get_tensor_with_same_compression(wcc)
-    if tens_cached is not None:
-        # Cache hit, copy weights from the cache
-        tens.copy_compressed_weight_info(tens_cached)
-        set_storage_shape(tens)
-        return nng.weight_cache.get_unencoded_size_with_same_compression(wcc)
-    # No cache hit, perform the compression
-    assert tens.quantization is not None
-    assert tens.quantization.scale_f32 is not None
-    assert tens.quantization.zero_point is not None
-
-    zero_point = tens.quantization.zero_point
-    quant_buf = tens.quant_values.astype(np.int64)
-
-    # Early zero-point correction
-    weights = quant_buf - zero_point
-
-    if len(weights.shape) == 2:
-        weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
-
-    compression_scales = []
-    compressed_offsets = []
-    encoded_streams = []
-    encoded_streams_substream_offsets = []
-    offset = 0
-    max_single_buffer_len = 0
-    unencoded_size = 0
-
-    ifm_bitdepth = tens.consumer_list[0].inputs[0].dtype.size_in_bits()
-    ifm_depth = weights.shape[-2]
-    if npu_block_type == NpuBlockType.ConvolutionDepthWise:
-        tens.block_traversal = TensorBlockTraversal.DepthWise
-    if npu_block_type == NpuBlockType.ConvolutionMxN:
-        # Determine which block traversal strategy has better DPU utilization
-        kernel_size = weights.shape[0] * weights.shape[1]
-        depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
-        part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
-            kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
-        )
-        if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
-            # Part-kernel first is always better for ifm depths <= 8
-            tens.block_traversal = TensorBlockTraversal.PartKernelFirst
-        else:
-            tens.block_traversal = TensorBlockTraversal.DepthFirst
-
-    is_depthwise = tens.block_traversal == TensorBlockTraversal.DepthWise
-    if tens.block_traversal == TensorBlockTraversal.PartKernelFirst:
-        block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
-    else:
-        block_traversal = NpuBlockTraversal.DEPTH_FIRST
-
-    if tens.consumer_list[0].type == Op.Conv2DBackpropInputSwitchedBias:
-        # Transpose Convoluion, reverse weights in H and W axes
-        weights = np.flip(weights, axis=(0, 1))
-
-    # Calculate brick size
-    brick_size = (weights.shape[0], weights.shape[1], weights.shape[2], min(tens.shape[-1], ofm_depth_step))
-    elements_in_brick = np.prod(brick_size)
-
-    # Slice weight stream up depth-ways into bricks and compress
-    full_ofm_depth = quant_buf.shape[-1]
-    for idx in range(0, full_ofm_depth, ofm_depth_step):
-        # Get the weights necessary for this brick
-        count = min(full_ofm_depth - idx, ofm_depth_step)
-        brick_weights = weights[:, :, :, idx : idx + count]
-
-        substream_offsets = [0]
-        encoded_stream = []
-
-        # For each core, deinterleave weights from the larger volume
-        # and generate separate compressed streams.
-        for core in range(0, min(arch.ncores, full_ofm_depth)):
-            core_weights = core_deinterleave(brick_weights, core, arch.ncores)
-
-            block_depth = (ofm_block_depth + arch.ncores - 1 - core) // arch.ncores
-            encoded_substream = []
-            if block_depth != 0:
-                encoded_substream, raw_stream_size = encode_weights(
-                    accelerator=arch.accelerator_config,
-                    weights_volume=core_weights,
-                    dilation_xy=dilation,
-                    ifm_bitdepth=ifm_bitdepth,
-                    ofm_block_depth=block_depth,
-                    is_depthwise=is_depthwise,
-                    block_traversal=block_traversal,
-                )
-                unencoded_size += raw_stream_size
-            encoded_stream.extend(encoded_substream)
-            substream_offsets.append(len(encoded_stream))
-
-        encoded_streams.append(encoded_stream)
-        encoded_streams_substream_offsets.append(substream_offsets)
-
-        # Remember maximum encoded length for DoubleBuffering
-        max_single_buffer_len = max(max_single_buffer_len, len(encoded_stream))
-
-        # Remember where we put it for linear addressing
-        compressed_offsets.append(offset)
-        offset += len(encoded_stream)
-        assert offset % 16 == 0
-
-        # Compression scale tracking
-        compression_scales.append(len(encoded_stream) / elements_in_brick)
-
-    # Track total length as last element of the offsets array
-    compressed_offsets.append(offset)
-
-    tens.weight_compression_scales = compression_scales
-    tens.weight_compressed_offsets = compressed_offsets
-    tens.compression_scale_for_worst_weight_stream = np.amax(compression_scales)
-    tens.storage_compression_scale = tens.bandwidth_compression_scale = np.average(compression_scales)
-    tens.compressed_values = encoded_streams
-    tens.compressed_values_substream_offsets = encoded_streams_substream_offsets
-    tens.brick_size = brick_size
-    set_storage_shape(tens)
-    nng.weight_cache.add(tens, unencoded_size)
-    return unencoded_size
-
-
-def calc_scales_and_pack_biases(tens, arch, ofm_depth_step, rescale_for_faf=False):
+def _prepare_scale_and_bias(arch, tens, rescale_for_faf):
     assert tens.purpose in [TensorPurpose.FeatureMap, TensorPurpose.FSBias]
     assert tens.format == TensorFormat.NHWC
     # the connected operator should expect a bias input unless it is a FullyConnected
@@ -381,79 +268,157 @@
     else:
         quantised_scales = [quantise_scale(scale) for scale in scales]
 
-    # pack the biases and scales
+    # If only 1 quantised scale is used, repeat that value for the length of the biases
     if len(quantised_scales) == 1:
-        # If only 1 quantised scale is used, repeat that value for the length of the biases
         quantised_scales = [quantised_scales[0]] * len(biases)
 
-    assert len(quantised_scales) == len(biases)
-    tens.element_size_bytes = 10
-    tens.compressed_values = []
-    tens.compressed_values_substream_offsets = []
-
-    total_elements = len(quantised_scales)
-    alignment_bytes = 0
-    for i in range(0, total_elements, ofm_depth_step):
-        # Extract streams from brick to generate substreams for each core
-        stream = bytearray()
-        substream_offsets = [0]
-        max_len = min(ofm_depth_step, total_elements - i)
-        for core in range(0, min(arch.ncores, max_len)):
-            core_scales = quantised_scales[i + core : i + core + max_len : arch.ncores]
-            core_biases = biases[i + core : i + core + max_len : arch.ncores]
-            for j, core_bias in enumerate(core_biases):
-                stream.extend(encode_bias(np.int64(core_bias), *core_scales[j]))
-
-            # Align to 16 for start for next substream
-            remainder = (len(stream)) % 16
-            if remainder > 0:
-                stream.extend(bytearray(16 - remainder))
-                alignment_bytes += 16 - remainder
-
-            substream_offsets.append(len(stream))
-
-        # Add to compressed values with their substream offset lists to the tensor
-        tens.compressed_values.append(stream)
-        tens.compressed_values_substream_offsets.append(substream_offsets)
-
-    tens.storage_shape = [total_elements + round_up_divide(alignment_bytes, tens.element_size_bytes)]
+    return quantised_scales, biases
 
 
-def update_pass_weight_and_scale_tensors(nng, arch):
-    for sg in nng.subgraphs:
-        for ps in sg.passes:
-            tens = ps.weight_tensor
-            if tens is not None:
-                op = tens.find_npu_op()
-                if op is None:
-                    continue
-                needs_dma = tens.needs_dma()
-                if ps.cascade.strategy == SchedulingStrategy.WeightStream and needs_dma:
-                    ofm_depth_step = ps.block_config[-1]
-                else:
-                    ofm_depth_step = tens.shape[-1]
-                nng.total_npu_weights += compress_weights(
-                    arch, nng, tens, op.type.npu_block_type, ps.block_config[-1], ofm_depth_step, op.get_dilation_h_w()
+def encode_weight_and_scale_tensor(
+    arch, op, weight_tens, scale_tens, kernel, block_config, depth_offsets, rescale_for_faf=False
+) -> NpuWeightTensor:
+    npu_block_type = op.type.npu_block_type
+
+    wcc = create_weight_compression_config(
+        weight_tens, scale_tens, npu_block_type, block_config.ofm_block.depth, hash(str(depth_offsets)), kernel.dilation
+    )
+
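+    # Reuse a previously encoded tensor if an identical weight compression config has been seen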
+    tens_cached = CompressedWeightCache.get_tensor_with_same_compression(wcc)
+    if tens_cached is not None:
+        return tens_cached
+
+    npu_tensor = NpuWeightTensor(weight_tens.name)
+    npu_tensor.weight_compression_config = wcc
+
+    # No cache hit, perform the compression
+    assert weight_tens.quantization is not None
+    assert weight_tens.quantization.scale_f32 is not None
+    assert weight_tens.quantization.zero_point is not None
+
+    zero_point = weight_tens.quantization.zero_point
+    quant_buf = weight_tens.quant_values.astype(np.int64)
+
+    # Early zero-point correction
+    weights = quant_buf - zero_point
+
+    if len(weights.shape) == 2:
+        weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
+
+    # The (undilated) kernel dimensions are expected to match the weight tensor's H and W
+    assert kernel.height == weights.shape[0]
+    assert kernel.width == weights.shape[1]
+    # Ensure depth offsets are terminated at end of OFM shape
+    assert len(depth_offsets) > 1, "Require closed depth ranges"
+
+    ifm_bitdepth = op.inputs[0].dtype.size_in_bits()
+    ifm_depth = weights.shape[-2]
+
+    # Default HW traversal
+    npu_tensor.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
+
+    if npu_block_type == NpuBlockType.ConvolutionMxN:
+        # Determine which block traversal strategy has better DPU utilization
+        kernel_size = weights.shape[0] * weights.shape[1]
+        depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
+        part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
+            kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
+        )
+        if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
+            # Part-kernel first is always better for ifm depths <= 8
+            npu_tensor.hw_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
+
+    if op.type == Op.Conv2DBackpropInputSwitchedBias:
+        # Transpose convolution: reverse the weights in the H and W axes
+        weights = np.flip(weights, axis=(0, 1))
+
+    encoded_stream = bytearray()
+    max_single_buffer_len = 0
+    is_depthwise = npu_block_type == NpuBlockType.ConvolutionDepthWise
+
+    # Bias & scale
+    if scale_tens:
+        quantised_scales, biases = _prepare_scale_and_bias(arch, scale_tens, rescale_for_faf)
+        scale_tens.element_size_bytes = 10
+
+    # Slice the weight stream up depth-ways into bricks and compress
+    full_ofm_depth = quant_buf.shape[-1]
+    ofm_block_depth = block_config.ofm_block.depth
+
+    weight_range_index = 0
+    for idx, depth_offset in enumerate(depth_offsets[:-1]):
+        # Do not generate for offsets outside the OFM
+        assert depth_offset >= 0 and depth_offset < full_ofm_depth
+        depth_length = depth_offsets[idx + 1] - depth_offset
+
+        # Get the weights necessary for this brick
+        brick_weights = weights[:, :, :, depth_offset : depth_offset + depth_length]
+
+        buffer_start_offset = len(encoded_stream)
+
+        # For each core, deinterleave weights from the larger volume
+        # and generate separate compressed streams.
+        for core in range(0, min(arch.ncores, full_ofm_depth)):
+
+            core_block_depth = int((ofm_block_depth + arch.ncores - 1 - core) // arch.ncores)
+
+            if core_block_depth != 0:
+                key = WeightKey(core, depth_offset)
+                weight_range = WeightRange()
+                weight_range.offset = len(encoded_stream)
+                weight_range.index = weight_range_index
+                weight_range_index += 1
+
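+                # Each range holds the scale/bias bytes (padded to 16) followed by the encoded weight bytes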
+                # Scales & biases
+                if scale_tens:
+                    scale_stream = []
+                    core_scales = quantised_scales[
+                        depth_offset + core : depth_offset + core + depth_length : arch.ncores
+                    ]
+                    core_biases = biases[depth_offset + core : depth_offset + core + depth_length : arch.ncores]
+                    for j, core_bias in enumerate(core_biases):
+                        scale_stream.extend(encode_bias(np.int64(core_bias), *core_scales[j]))
+
+                    weight_range.scale_bytes = len(scale_stream)
+
+                    encoded_stream.extend(scale_stream)
+
+                    # Align to 16 for start of next substream
+                    remainder = len(encoded_stream) % 16
+                    if remainder > 0:
+                        encoded_stream.extend(bytearray(16 - remainder))
+
+                # Weights
+                core_weights = core_deinterleave(brick_weights, core, arch.ncores)
+                encoded_substream, _ = encode_weights(
+                    accelerator=arch.accelerator_config,
+                    weights_volume=core_weights,
+                    dilation_xy=kernel.dilation,
+                    ifm_bitdepth=ifm_bitdepth,
+                    ofm_block_depth=core_block_depth,
+                    is_depthwise=is_depthwise,
+                    block_traversal=npu_tensor.hw_traversal,
                 )
-                nng.total_npu_encoded_weights += tens.weight_compressed_offsets[-1]
-                nng.total_original_weights += int(tens.elements() * tens.element_size())
 
-                # Update source tensor
-                if needs_dma:
-                    src_tens = tens.get_dma_src_tensor()
-                    src_tens.shape = tens.shape
-                    src_tens.quant_values = tens.quant_values
-                    src_tens.copy_compressed_weight_info(tens)
-                    set_storage_shape(src_tens)
+                weight_range.weight_offset = len(encoded_stream) - weight_range.offset
+                weight_range.weight_bytes = len(encoded_substream)
 
-            if ps.scale_tensor is not None:
-                rescale_for_faf = False
-                if (ps.ops[-1].type in (Op.Sigmoid, Op.Tanh)) and (ps.npu_block_type != NpuBlockType.ElementWise):
-                    rescale_for_faf = True
-                calc_scales_and_pack_biases(ps.scale_tensor, arch, ofm_depth_step, rescale_for_faf)
-                if ps.scale_tensor.ops[0].type == Op.DMA:
-                    src_tens = ps.scale_tensor.get_dma_src_tensor()
-                    src_tens.shape = ps.scale_tensor.shape
-                    src_tens.quant_values = ps.scale_tensor.quant_values
-                    src_tens.element_size_bytes = ps.scale_tensor.element_size_bytes
-                    src_tens.copy_compressed_weight_info(ps.scale_tensor)
+                # Append encoded weights section
+                encoded_stream.extend(encoded_substream)
+                assert len(encoded_stream) % 16 == 0
+
+                # Record encoded range in weights tensor
+                npu_tensor.encoded_ranges[key] = weight_range
+
+        # Remember maximum encoded length for DoubleBuffering
+        max_single_buffer_len = max(max_single_buffer_len, len(encoded_stream) - buffer_start_offset)
+
+    npu_tensor.buffer = encoded_stream
+    npu_tensor.max_range_bytes = max_single_buffer_len
+    npu_tensor.set_all_shapes([1, 1, 1, len(encoded_stream)])
+    npu_tensor.format = TensorFormat.WeightsCompressed
+    npu_tensor.purpose = TensorPurpose.Weights
+    npu_tensor.mem_area = weight_tens.mem_area
+    npu_tensor.mem_type = weight_tens.mem_type
+    CompressedWeightCache.add(npu_tensor)
+    return npu_tensor