MLBEDSW-3643: Refactor blockdep calculation

Moved the blockdep calculation and other code generation helper
functions to a new file, register_command_stream_util.py.

Change-Id: I2f8ccea478654272ebf42217fc5c1800e9ad177a
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index f7dcc8c..354ab12 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -192,6 +192,7 @@
     SubKernelMax = Block(8, 8, 65536)
 
     DEFAULT_CONFIG = "internal-default"
+    MAX_BLOCKDEP = 3
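+    # Upper bound for the BLOCKDEP register value (see calc_blockdep in register_command_stream_util)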
 
     def __init__(
         self,
@@ -442,143 +443,6 @@
 
         return Block(ifm_block_width, ifm_block_height, ifm_block_depth)
 
-    @staticmethod
-    def intersects(start_a, end_a, start_b, end_b):
-        start_x = max(start_a[0], start_b[0])
-        end_x = min(end_a[0], end_b[0])
-        start_y = max(start_a[1], start_b[1])
-        end_y = min(end_a[1], end_b[1])
-        start_z = max(start_a[2], start_b[2])
-        end_z = min(end_a[2], end_b[2])
-        return ((end_x - start_x) > 0) and ((end_y - start_y) > 0) and ((end_z - start_z) > 0)
-
-    # Block job dependency:
-    # Does the VOLUME of IFMs for block job B(0) overlap with VOLUME of OFMs block jobs A(8,9,10)
-    #
-    #  A                    | B
-    # ----------------------+------------------
-    # .... 3,4,5,6,7,8,9,10 | 0,1,2,3,4,5,6,8 10 < JOB NUMBER
-    #               |<------->| dependency offset
-    #
-    MAX_BLOCKDEP = 3
-
-    # Get the coordinates of a block offset from either the end (negative)
-    # or the start (zero or positive) of the given 3d area
-    def get_offset_block_coords(self, area: Rect, block: Block, offset):
-        size = area.size()
-        # Dimensions of the region, in blocks
-        width_blocks = round_up_divide(size.width, block.width)
-        height_blocks = round_up_divide(size.height, block.height)
-        depth_blocks = round_up_divide(size.depth, block.depth)
-        total_blocks = width_blocks * height_blocks * depth_blocks
-        if offset < 0:
-            index = total_blocks + offset
-        else:
-            index = offset
-
-        if index >= total_blocks:
-            return None
-
-        # Coordinates of the indexed block
-        coord_z = block.depth * (index % depth_blocks)
-        coord_y = block.height * (index // (depth_blocks * width_blocks))
-        coord_x = block.width * ((index // depth_blocks) % width_blocks)
-
-        return (coord_x + area.x, coord_y + area.y, coord_z + area.z)
-
-    def get_first_job_input_volume(
-        self, ifm: Rect, ofm: Rect, ifm_block_depth, ofm_block: Block, kernel: Kernel, padLT, block_offset
-    ):
-        # Get ifm block size (jobs are invisibly decomposed into subkernels)
-        ifm_block = self.get_ifm_block_size(ifm_block_depth, ofm_block, kernel, self.ofm_block_max)
-        ifm_depth_blocks = round_up_divide(ifm.size().depth, ifm_block_depth)
-
-        # Which OFM block are we calculating
-        ofm_coord = self.get_offset_block_coords(ofm, ofm_block, block_offset // ifm_depth_blocks)
-        if ofm_coord is None:
-            return None
-
-        # Coordinate of the source IFM block
-        ifm_coord_x = max(0, ofm_coord[0] * kernel.stride.x - padLT[0])
-        ifm_coord_y = max(0, ofm_coord[1] * kernel.stride.y - padLT[1])
-        ifm_coord_z = ifm.z + (block_offset % ifm_depth_blocks) * ifm_block.depth
-
-        # IFM block that will be sampled for the FIRST+block_offset job in the next operator's OFM
-        start_coord = (ifm_coord_x, ifm_coord_y, ifm_coord_z)
-        end_coord = (
-            start_coord[0] + ifm_block.width,
-            start_coord[1] + ifm_block.height,
-            start_coord[2] + ifm_block.depth,
-        )
-        return (start_coord, end_coord, 1)  # start, end, total jobs
-
-    def get_prev_job_output_volume(self, ofm: Rect, ofm_block: Block, block_offset):
-        assert block_offset >= 0
-
-        # Get OFM block's volume coordinates
-        start_coord = self.get_offset_block_coords(ofm, ofm_block, -1 - block_offset)
-        if start_coord is None:
-            return None
-        end_coord = (
-            start_coord[0] + ofm_block.width,
-            start_coord[1] + ofm_block.height,
-            start_coord[2] + ofm_block.depth,
-        )
-        return (start_coord, end_coord, 1)  # start, end, total jobs for this OFM block
-
-    def calc_block_dep(
-        self,
-        prev_ofm: Rect,
-        prev_ofm_block: Block,
-        ifm: Rect,
-        ofm: Rect,
-        ifm_block_depth,
-        ofm_block: Block,
-        kernel: Kernel,
-        padLT,
-        intersects,
-    ):
-        blockdep = ArchitectureFeatures.MAX_BLOCKDEP
-
-        # Iterate over the next BLOCKDEP inputs, checking to see if a sliding window
-        # of IFM area overlaps with any previous OFM block generation.
-        elapsed_jobs = 0
-        for forward_offset in range(ArchitectureFeatures.MAX_BLOCKDEP):
-            # This is the IFM block we want to sample from
-            in_area = self.get_first_job_input_volume(
-                ifm, ofm, ifm_block_depth, ofm_block, kernel, padLT, forward_offset
-            )
-            if in_area is None:
-                break
-
-            # Try several previous-OFM blocks in the past (they still might comprise multiple IFM jobs)
-            outstanding_jobs = 0
-            for block_offset in range(ArchitectureFeatures.MAX_BLOCKDEP):
-                # This is the OFM block being generated by the previous op
-                out_area = self.get_prev_job_output_volume(prev_ofm, prev_ofm_block, block_offset)
-                if out_area is None:
-                    break
-
-                # Block dependency is the max number of allowed outstanding jobs
-                # in the pipeline. Selected by determining how many jobs occur
-                # in between two operators' overlapping OFM->IFM block volumes
-                if intersects(in_area[0], in_area[1], out_area[0], out_area[1]):
-                    break
-                # Early exit if no intersections and we've seen enough jobs in the pipeline
-                elif outstanding_jobs > ArchitectureFeatures.MAX_BLOCKDEP:
-                    break
-
-                # This OFM had this many jobs (accumulate over multiple OFM blocks)
-                outstanding_jobs += out_area[2]
-
-            blockdep = min(blockdep, elapsed_jobs + outstanding_jobs)
-            elapsed_jobs += in_area[2]
-            # Early exit if no intersections and we've seen enough jobs in the pipeline
-            if elapsed_jobs > ArchitectureFeatures.MAX_BLOCKDEP:
-                break
-
-        return blockdep
-
     def is_spilling_enabled(self):
         """
         Spilling is a feature that allows the Ethos-U to use a dedicated SRAM as a cache for various types of data
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index d17f1e5..6c7fdc1 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -20,6 +20,7 @@
 from . import extract_npu_subgraphs
 from . import graph_optimiser
 from . import high_level_command_stream_generator
+from . import high_level_command_to_npu_op
 from . import insert_dma
 from . import live_range
 from . import lut
@@ -27,7 +28,6 @@
 from . import npu_performance
 from . import npu_serialisation
 from . import pass_packing
-from . import register_command_stream_generator
 from . import scheduler
 from . import tensor_allocation
 from . import weight_compressor
@@ -289,7 +289,7 @@
             nng, sg, arch, options.verbose_high_level_command_stream
         )
         lut.optimize_high_level_cmd_stream(sg, arch)
-        register_command_stream_generator.generate_register_command_stream_for_sg(
+        high_level_command_to_npu_op.generate_register_command_stream_for_sg(
             nng, sg, arch, options.verbose_register_command_stream
         )
         scratch_tens, scratch_fast_tens, flash_tens = npu_serialisation.serialise_npu_subgraph_into_tensors(
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index efd8a03..7db4931 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -32,7 +32,6 @@
 from .api import NpuElementWiseOp
 from .api import NpuElementWiseOperation
 from .api import NpuFeatureMap
-from .api import NpuKernel
 from .api import NpuLayout
 from .api import NpuOperation
 from .api import NpuPadding
@@ -46,15 +45,20 @@
 from .architecture_features import ArchitectureFeatures
 from .architecture_features import Block
 from .data_type import DataType
+from .debug_database import DebugDatabase
 from .high_level_command_stream import Box
 from .high_level_command_stream import Command
 from .high_level_command_stream import CommandType
 from .high_level_command_stream import DMA
 from .high_level_command_stream import NpuStripe
-from .operation import Kernel
 from .operation import NpuBlockType
 from .operation import Op
 from .operation import Operation
+from .register_command_stream_generator import generate_command_stream
+from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
+from .register_command_stream_util import is_dma_op
+from .register_command_stream_util import to_npu_kernel
+from .register_command_stream_util import UNARY_ELEMWISE_OPS
 from .tensor import MemType
 from .tensor import Tensor
 from .tensor import TensorBlockTraversal
@@ -62,14 +66,10 @@
 from .tensor import TensorPurpose
 
 
-unary_elementwise_ops = set((NpuElementWiseOp.ABS, NpuElementWiseOp.LRELU, NpuElementWiseOp.CLZ,))
-
-
 class BasePointerIndex(IntEnum):
     WeightTensor = 0  # base address index for the Weight tensor
     ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
     ScratchFastTensor = 2  # base address for the Scratch_fast_tensor
-    Mem2Mem = (1 << 8) | (3 << 0)  # base address slot for memory 2 memory transfer
 
 
 dtype_map = {
@@ -102,20 +102,6 @@
 }
 
 
-def to_npu_kernel(kernel: Kernel) -> NpuKernel:
-    """Converts the given internally used kernel object to NpuKernel (of public API)"""
-    return NpuKernel(
-        kernel.width, kernel.height, kernel.stride.x, kernel.stride.y, kernel.dilation.x, kernel.dilation.y
-    )
-
-
-def to_kernel(kernel: Optional[NpuKernel]) -> Kernel:
-    """Converts the given public API object to Kernel (used internally)"""
-    if kernel is None:
-        return Kernel(1, 1)
-    return Kernel(kernel.width, kernel.height, kernel.stride_x, kernel.stride_y, kernel.dilation_x, kernel.dilation_y)
-
-
 def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
     if ifm_shape == []:
         # Scalar needs to be in IFM2
@@ -412,7 +398,7 @@
     assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
     elemwise_op = elementwise_op_map[op.type]
     npu_op = NpuElementWiseOperation(elemwise_op)
-    if elemwise_op not in unary_elementwise_ops:
+    if elemwise_op not in UNARY_ELEMWISE_OPS:
         if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape):
             # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
             cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
@@ -452,7 +438,7 @@
     """Converts the command to NpuDmaOperation"""
     src_region = get_region(cmd.in_tensor, arch)
     if cmd.out_tensor.purpose == TensorPurpose.LUT:
-        dest_region = BasePointerIndex.Mem2Mem
+        dest_region = BASE_PTR_INDEX_MEM2MEM
     else:
         dest_region = get_region(cmd.out_tensor, arch)
 
@@ -492,3 +478,28 @@
     # add a link to the high level command for debugging purposes
     npu_op.cmd = cmd
     return npu_op
+
+
+def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
+    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
+    # Convert high level command stream to list of NpuOperation
+    npu_op_list = []
+    npu_op_to_cmd = dict()  # map from npu op to high level command
+    for cmd in sg.high_level_command_stream:
+        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default:
+            print("Warning: Skipping register command stream generation for", cmd.ps)
+        else:
+            npu_op = convert_command_to_npu_op(cmd, arch)
+            npu_op_list.append(npu_op)
+            npu_op_to_cmd[npu_op] = cmd
+    # Generate register commands
+    stream_id = DebugDatabase.add_stream(sg)
+    DebugDatabase.set_stream_offset(sg, 0)  # Default to zero, can only be set during file writing
+
+    def add_to_debug_db(npu_op: NpuOperation, offset: int):
+        """Adds info to the debug database"""
+        if not is_dma_op(npu_op):
+            cmd = npu_op_to_cmd[npu_op]
+            DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)
+
+    sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db)
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 741b09c..d4947b1 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -18,16 +18,13 @@
 # all the register settings. Calculates dependencies between commands and inserts wait operations. And generates a bit
 # stream suitable for interpretation by the Ethos-U processor.
 from collections import defaultdict
-from collections import namedtuple
 from enum import Enum
 from enum import IntEnum
 from typing import List
 from typing import Optional
-from typing import Tuple
 
 import numpy as np
 
-from . import numeric_util
 from . import scaling
 from .api import NpuAccelerator
 from .api import NpuActivation
@@ -57,10 +54,8 @@
 from .architecture_features import ArchitectureFeatures
 from .architecture_features import Block
 from .architecture_features import create_default_arch
-from .architecture_features import Rect
 from .architecture_features import SharedBufferArea
 from .architecture_features import SHRAMElements
-from .debug_database import DebugDatabase
 from .ethos_u55_regs.ethos_u55_regs import acc_format
 from .ethos_u55_regs.ethos_u55_regs import activation
 from .ethos_u55_regs.ethos_u55_regs import cmd0
@@ -69,17 +64,20 @@
 from .ethos_u55_regs.ethos_u55_regs import pooling_mode
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
 from .ethos_u55_regs.ethos_u55_regs import rounding
-from .high_level_command_stream import CommandType
-from .high_level_command_to_npu_op import convert_command_to_npu_op
-from .high_level_command_to_npu_op import to_kernel
-from .high_level_command_to_npu_op import unary_elementwise_ops
 from .numeric_util import quantise_float32
 from .numeric_util import round_away_zero
 from .numeric_util import round_up_to_int
 from .operation import NpuBlockType
-from .range_set import AccessDirection
-from .range_set import MemoryAccessSet
-from .range_set import MemoryRangeSet
+from .register_command_stream_util import calc_blockdep
+from .register_command_stream_util import get_dma_memory_accesses
+from .register_command_stream_util import get_op_memory_accesses
+from .register_command_stream_util import get_strides
+from .register_command_stream_util import get_wait_dependency
+from .register_command_stream_util import has_ifm2
+from .register_command_stream_util import is_dma_op
+from .register_command_stream_util import to_kernel
+from .register_command_stream_util import UNARY_ELEMWISE_OPS
+from .register_command_stream_util import Watermark
 from .shared_buffer_allocation import find_suitable_block_configs
 from .shared_buffer_allocation import shared_buffer_allocation_for_npu_op
 from .shared_buffer_allocation import SharedBufferAllocation
@@ -203,13 +201,6 @@
 # -------------------------------------------------------------------
 
 
-class BasePointerIndex(IntEnum):
-    WeightTensor = 0  # base address index for the Weight tensor
-    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
-    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor
-    Mem2Mem = (1 << 8) | (3 << 0)  # base address slot for memory 2 memory transfer
-
-
 # TODO: Replace with definitions from ethos_u55_regs
 class IFM2Broadcast(IntEnum):
     BroadcastHdim = 1 << 0
@@ -275,16 +266,6 @@
     return quantise_float32(value, scale, zp)
 
 
-def has_ifm2(npu_op: NpuBlockOperation) -> bool:
-    """Checks if op has non-scalar IFM2"""
-    return npu_op.ifm2 is not None and npu_op.ifm2_scalar is None
-
-
-def is_dma_op(npu_op: NpuOperation) -> bool:
-    """Checks if op is a DMA operation"""
-    return npu_op.op_type == NpuOperationType.Dma
-
-
 def generate_padding(emit: CommandStreamEmitter, padding: NpuPadding):
     """Generates IFM_PAD registers"""
     emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_TOP, padding.top)
@@ -584,6 +565,15 @@
     return shared_buffer_allocation_for_npu_op(arch, npu_op, block_type, ifm_resampling_mode)
 
 
+def generate_cmd_waits(emit: CommandStreamEmitter, cmd_waits: Watermark):
+    """Generates KERNEL_WAIT/DMA_WAIT"""
+    if cmd_waits.npu >= 0:
+        emit.cmd_wait(cmd0.NPU_OP_KERNEL_WAIT, 0, cmd_waits.npu)
+
+    if cmd_waits.dma >= 0:
+        emit.cmd_wait(cmd0.NPU_OP_DMA_WAIT, 0, cmd_waits.dma)
+
+
 def generate_common(
     emit: CommandStreamEmitter,
     npu_op: NpuBlockOperation,
@@ -735,353 +725,6 @@
 
 
 # -------------------------------------------------------------------
-# ADDRESSING/STRIDES (helper functions)
-# -------------------------------------------------------------------
-
-
-def ranges_overlap(range1: NpuAddressRange, range2: NpuAddressRange) -> bool:
-    """Checks if the ranges overlap"""
-    return range1.region == range2.region and numeric_util.overlaps(
-        range1.address, range1.address + range1.length, range2.address, range2.address + range2.length
-    )
-
-
-def range_lists_overlap(list1: List[Optional[NpuAddressRange]], list2: List[Optional[NpuAddressRange]]) -> bool:
-    """Checks if there is any address overlap between list1 and list2"""
-    for range1 in list1:
-        if range1 is None:
-            continue
-        for range2 in list2:
-            if range2 is not None and ranges_overlap(range1, range2):
-                return True
-    return False
-
-
-def get_strides(fm: NpuFeatureMap) -> NpuShape3D:
-    """Calculates STRIDE_C/Y/X"""
-    if fm.strides is not None:
-        return fm.strides
-    elem_size = fm.data_type.size_in_bytes()
-    if fm.layout == NpuLayout.NHWC:
-        stride_c = elem_size
-        stride_x = fm.shape.depth * stride_c
-        stride_y = fm.shape.width * stride_x
-    else:
-        stride_x = 16 * elem_size
-        stride_c = stride_x * fm.shape.width
-        stride_y = elem_size * fm.shape.width * numeric_util.round_up(fm.shape.depth, 16)
-    return NpuShape3D(depth=stride_c, height=stride_y, width=stride_x)
-
-
-def get_address(fm: NpuFeatureMap, strides: NpuShape3D, y: int, x: int, c: int) -> int:
-    """Returns address of given coordinate"""
-    t = 0
-    BRICK = 16
-    stride_c = BRICK * fm.data_type.size_in_bytes() if fm.layout == NpuLayout.NHWC else strides.depth
-    stride_x = BRICK * fm.data_type.size_in_bytes() if fm.layout == NpuLayout.NHCWB16 else strides.width
-    if x >= fm.tiles.width_0:
-        x -= fm.tiles.width_0
-        t = 1
-        if y >= fm.tiles.height_1:
-            y -= fm.tiles.height_1
-            t += 2
-    elif y >= fm.tiles.height_0:
-        y -= fm.tiles.height_0
-        t += 2
-    elem_size = fm.data_type.size_in_bytes()
-    return (
-        fm.tiles.addresses[t] + y * strides.height + x * stride_x + (c // BRICK) * stride_c + int(c % BRICK) * elem_size
-    )
-
-
-def get_address_range(
-    fm: NpuFeatureMap, strides: NpuShape3D, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
-) -> NpuAddressRange:
-    """
-    Gets address range for (y0, x0, c0) - (y1, x1, c1) (inclusive, so the second coordinate is within the fm).
-    The begin and end coordinates must be within the same tile.
-    """
-    addr0 = get_address(fm, strides, y0, x0, c0)
-    addr1 = get_address(fm, strides, y1, x1, c1)
-    return NpuAddressRange(region=fm.region, address=addr0, length=addr1 - addr0 + fm.data_type.size_in_bytes())
-
-
-def get_h_ranges(
-    fm: NpuFeatureMap, strides: NpuShape3D, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
-) -> List[NpuAddressRange]:
-    """
-    Gets address ranges for (y0, x0, c0) - (y1, x1, c1) (inclusive, so the second coordinate is within the fm);
-    the begin and end coordinates must be within the same tile.
-    Divides the area in horizontal "stripes" of height 1, and returns the address ranges for these "stripes".
-    """
-    return [get_address_range(fm, strides, y, x0, c0, y, x1, c1) for y in range(y0, y1 + 1)]
-
-
-def get_address_ranges_for_area(
-    fm: NpuFeatureMap, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
-) -> List[NpuAddressRange]:
-    """
-    Returns a list of adddress ranges that covers the area (y0, x0, c0) - (y1, x1, c1) (inclusive).
-    Divides the area in horizontal "stripes" of height 1, and returns the address ranges for these "stripes".
-
-    For example, for the area marked with X (in a feature map with 4 tiles) as input, this function would return
-    6 address ranges: the address ranges for 1-height areas [AAA, BBB, CC, DD, EEE, FF]
-
-        .....|....           .....|....
-     t0 ..XXX|XX.. t1     t0 ..AAA|CC.. t1
-        ..XXX|XX..           ..BBB|DD..
-        -----+----    -->    -----+----
-     t2 ..XXX|XX.. t3     t2 ..EEE|FF.. t3
-        .....|....           .....|....
-    """
-    strides = get_strides(fm)
-    height_0, height_1, width_0 = fm.tiles.height_0, fm.tiles.height_1, fm.tiles.width_0
-    h, w, c = fm.shape
-    y2, x2, c2 = min(y1, h - 1), min(x1, w - 1), min(c1, c - 1)
-    ranges = []
-    if x0 < width_0 and y0 < height_0:
-        # Horizontal ranges for tile 0
-        ranges.extend(get_h_ranges(fm, strides, y0, x0, c0, min(y2, height_0 - 1), min(x2, width_0 - 1), c2))
-    if x2 >= width_0 and y0 < height_1:
-        # Horizontal ranges for tile 1
-        ranges.extend(get_h_ranges(fm, strides, y0, max(x0, width_0), c0, min(y2, height_1 - 1), x2, c2))
-    if x0 < width_0 and y2 >= height_0:
-        # Horizontal ranges for tile 2
-        ranges.extend(get_h_ranges(fm, strides, max(y0, height_0), x0, c0, y2, min(x2, width_0 - 1), c2))
-    if x2 >= width_0 and y2 >= height_1:
-        # Horizontal ranges for tile 3
-        ranges.extend(get_h_ranges(fm, strides, max(y0, height_1), max(x0, width_0), c0, y2, x2, c2))
-    return ranges
-
-
-def get_address_ranges(fm: NpuFeatureMap) -> List[Optional[NpuAddressRange]]:
-    """Returns 4 adddress ranges, one for every tile, None if the tile is not in use"""
-    strides = get_strides(fm)
-    height, width, depth = fm.shape.height, fm.shape.width, fm.shape.depth
-    height_0, height_1, width_0 = fm.tiles.height_0, fm.tiles.height_1, fm.tiles.width_0
-    t0 = get_address_range(fm, strides, 0, 0, 0, min(height, height_0) - 1, min(width, width_0) - 1, depth - 1,)
-    if width > width_0:
-        t1 = get_address_range(fm, strides, 0, width_0, 0, min(height, height_1) - 1, width - 1, depth - 1)
-    else:
-        t1 = None
-    if height > height_0:
-        t2 = get_address_range(fm, strides, height_0, 0, 0, height - 1, min(width, width_0) - 1, depth - 1)
-    else:
-        t2 = None
-    if t1 is not None and t2 is not None:
-        t3 = get_address_range(fm, strides, height_1, width_0, 0, height - 1, width - 1, depth - 1)
-    else:
-        t3 = None
-    return [t0, t1, t2, t3]
-
-
-# -------------------------------------------------------------------
-# DMA_WAIT/KERNEL_WAIT
-# -------------------------------------------------------------------
-
-
-Watermark = namedtuple("Watermark", ["npu", "dma"])
-
-
-def memory_range_set(range: NpuAddressRange) -> MemoryRangeSet:
-    return MemoryRangeSet(range.region, range.address, range.address + range.length)
-
-
-def get_dma_memory_accesses(dma_op: NpuDmaOperation) -> MemoryAccessSet:
-    """Returns the address that are read and written by the given DMA operation"""
-    res = MemoryAccessSet()
-    res.add(memory_range_set(dma_op.src), AccessDirection.Read)
-    res.add(memory_range_set(dma_op.dest), AccessDirection.Write)
-    return res
-
-
-def get_op_memory_accesses(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> MemoryAccessSet:
-    """Returns the addresses that are read and written by the given operation"""
-    assert npu_op.ifm is not None and npu_op.ofm is not None
-    # Read addresses
-    read_ranges = get_address_ranges(npu_op.ifm)
-    if has_ifm2(npu_op):
-        assert npu_op.ifm2 is not None
-        read_ranges.extend(get_address_ranges(npu_op.ifm2))
-    read_ranges.extend(npu_op.weights)
-    read_ranges.extend(npu_op.biases)
-    if npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP:
-        address = arch.available_shram_banks(True) * arch.shram_bank_size
-        read_ranges.append(NpuAddressRange(region=BasePointerIndex.Mem2Mem, address=address, length=2048))
-    # Written addresses
-    write_ranges = get_address_ranges(npu_op.ofm)
-    # Add write access to SHRAM, needed when LUTs can overwrite accumulator banks
-    uses_lut = npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP
-    written_shram_size = arch.available_shram_banks(uses_lut) * arch.shram_bank_size
-    write_ranges.append(NpuAddressRange(region=BasePointerIndex.Mem2Mem, address=0, length=written_shram_size))
-
-    res = MemoryAccessSet()
-    for read_range in read_ranges:
-        if read_range is not None:
-            res.add(memory_range_set(read_range), AccessDirection.Read)
-    for write_range in write_ranges:
-        if write_range is not None:
-            res.add(memory_range_set(write_range), AccessDirection.Write)
-    return res
-
-
-def get_wait_dependency(
-    arch: ArchitectureFeatures, npu_op_list: List[NpuOperation], memory_accesses, op_index: int, watermark: Watermark
-):
-    """Used to calculate whether DMA wait or kernel wait operations are needed"""
-    npu_op = npu_op_list[op_index]
-    op_access = memory_accesses[npu_op]
-    index = op_index - 1
-
-    # NPU dependency tracking
-    npu_outstanding = -1
-    npu_ops = 0
-    npu_index = watermark.npu
-
-    # DMA dependency tracking
-    dma_outstanding = -1
-    dma_ops = 0
-    dma_index = watermark.dma
-
-    # Seek back in the command stream looking for NPU or DMA dependencies
-    # but only as far as the first dependency or the watermarks (dependencies
-    # before this point have been satisfied already).
-    # The watermark moves to after the latest element we must wait for, not
-    # the command that issues the wait.
-    # NPU->NPU dependency is handled via blockdep.
-    while (index >= npu_index) or (index >= dma_index):
-        prev_op = npu_op_list[index]
-        prev_access = memory_accesses[prev_op]
-
-        # Check NPU consuming DMA output
-        if is_dma_op(prev_op):
-            if index >= dma_index:
-                if not is_dma_op(npu_op):
-                    if (dma_outstanding == -1) and prev_access.conflicts(op_access):
-                        dma_outstanding = dma_ops
-                dma_ops += 1  # Count DMA ops in the pipeline
-                if dma_ops >= arch.max_outstanding_dma:
-                    dma_index = max(index + 1, dma_index)
-        # Check DMA consuming NPU output
-        else:
-            if index >= npu_index:
-                if is_dma_op(npu_op) and npu_outstanding == -1 and prev_access.conflicts(op_access):
-                    npu_outstanding = npu_ops
-                npu_ops += 1  # Count NPU ops in the pipeline
-                if npu_ops >= arch.max_outstanding_kernels:
-                    npu_index = max(index + 1, npu_index)
-
-        index -= 1
-
-    # Update DMA watermark if we didn't see any and the NPU pipeline is full
-    if (dma_ops == 0) and (npu_ops >= arch.max_outstanding_kernels):
-        dma_index = op_index
-
-    # Bring the search watermark forwards as we complete for those dependencies
-    watermark = Watermark(npu_index, dma_index)
-    outstanding = Watermark(npu_outstanding, dma_outstanding)
-
-    return watermark, outstanding
-
-
-def generate_cmd_waits(emit: CommandStreamEmitter, cmd_waits: Watermark):
-    if cmd_waits.npu >= 0:
-        emit.cmd_wait(cmd0.NPU_OP_KERNEL_WAIT, 0, cmd_waits.npu)
-
-    if cmd_waits.dma >= 0:
-        emit.cmd_wait(cmd0.NPU_OP_DMA_WAIT, 0, cmd_waits.dma)
-
-
-# -------------------------------------------------------------------
-# BLOCKDEP
-# -------------------------------------------------------------------
-
-
-def shape3d_size(shape: NpuShape3D) -> int:
-    return shape.width * shape.height * shape.depth
-
-
-def shape3d_to_rect(shape: NpuShape3D) -> Rect:
-    return Rect(0, 0, 0, shape.width - 1, shape.height - 1, shape.depth - 1)
-
-
-def get_ifm_ofm_block_depth(arch: ArchitectureFeatures, npu_op: NpuBlockOperation) -> int:
-    # Note: NOT equivalent to the normal ifm block depth calculation since
-    # it takes into account 'depthless' block operations by returning full
-    # depth
-    if npu_op.op_type == NpuOperationType.Conv2D:
-        res = arch.calc_ifm_block_depth(npu_op.ifm.shape.depth, npu_op.ifm.data_type.size_in_bits())
-        return res
-    return npu_op.ofm.shape.depth
-
-
-def calc_blockdep(arch: ArchitectureFeatures, prev_op: Optional[NpuBlockOperation], npu_op: NpuBlockOperation,) -> int:
-    """Calculates the value of the BLOCKDEP register"""
-    if prev_op is None:
-        return 0
-    assert npu_op.ifm is not None
-    assert prev_op.ofm is not None
-    # Check if IFM or IFM2 overlaps with prev op's OFM
-    prev_ofm_ranges = get_address_ranges(prev_op.ofm)
-    ifm_ranges = get_address_ranges(npu_op.ifm)
-    ifm_overlaps = range_lists_overlap(prev_ofm_ranges, ifm_ranges)
-    if has_ifm2(npu_op):
-        assert npu_op.ifm2 is not None
-        ifm2_ranges = get_address_ranges(npu_op.ifm2)
-        ifm2_overlaps = range_lists_overlap(prev_ofm_ranges, ifm2_ranges)
-    else:
-        ifm2_overlaps = False
-    if ifm_overlaps and ifm2_overlaps:
-        # Both IFM and IFM2 overlap (should be rare)
-        return 0
-    if not ifm_overlaps and not ifm2_overlaps:
-        # No overlap between prev OFM and IFM/IFM2
-        return ArchitectureFeatures.MAX_BLOCKDEP
-    if ifm2_overlaps and shape3d_size(npu_op.ifm2.shape) < shape3d_size(npu_op.ifm.shape):
-        # Prev OFM produces IFM2 which is broadcasted (this should be rare)
-        return 0
-    prev_block_config = prev_op.block_config
-    block_config = npu_op.block_config
-    overlapping_fm = npu_op.ifm if ifm_overlaps else npu_op.ifm2
-    assert overlapping_fm is not None
-
-    def intersects(ifm_start_coord: Tuple, ifm_end_coord: Tuple, ofm_start_coord: Tuple, ofm_end_coord: Tuple) -> bool:
-        """Checks if the given IFM area overlaps with the given OFM area"""
-        if overlapping_fm.shape == prev_op.ofm.shape and overlapping_fm.tiles == prev_op.ofm.tiles:
-            # Common case: prev_op.ofm == op.ifm; in this case it suffices to check
-            # if the xyz coordinates overlap, which is quick and easy
-            return ArchitectureFeatures.intersects(ifm_start_coord, ifm_end_coord, ofm_start_coord, ofm_end_coord)
-        # The OFM produces a part of the IFM (e.g. a stripe), or the IFM consumes part of the OFM.
-        # In this case address comparison is needed between the two areas
-        x0, y0, c0 = ifm_start_coord
-        x1, y1, c1 = ifm_end_coord
-        ifm_ranges = get_address_ranges_for_area(overlapping_fm, y0, x0, c0, y1, x1, c1)
-        x0, y0, c0 = ofm_start_coord
-        x1, y1, c1 = ofm_end_coord
-        prev_ofm_ranges = get_address_ranges_for_area(prev_op.ofm, y0, x0, c0, y1, x1, c1)
-        return range_lists_overlap(ifm_ranges, prev_ofm_ranges)
-
-    prev_ofm_block = Block(prev_block_config.width, prev_block_config.height, prev_block_config.depth)
-    prev_ofm_rect = shape3d_to_rect(prev_op.ofm.shape)
-    cur_ifm_block_depth = get_ifm_ofm_block_depth(arch, npu_op)
-    cur_ofm_block = Block(block_config.width, block_config.height, block_config.depth)
-    cur_ofm_rect = shape3d_to_rect(npu_op.ofm.shape)
-    cur_ifm_rect = shape3d_to_rect(npu_op.ifm.shape)
-    cur_padLT = (0, 0) if npu_op.padding is None else (npu_op.padding.left, npu_op.padding.top)
-    return arch.calc_block_dep(
-        prev_ofm_rect,
-        prev_ofm_block,
-        cur_ifm_rect,
-        cur_ofm_rect,
-        cur_ifm_block_depth,
-        cur_ofm_block,
-        to_kernel(npu_op.kernel),
-        cur_padLT,
-        intersects=intersects,
-    )
-
-
-# -------------------------------------------------------------------
 # PRINT
 # -------------------------------------------------------------------
 
@@ -1209,7 +852,7 @@
         emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch, use_global_scale=use_global_scale, op_to_scale=op_to_scale
     )
     # Elementwise op specific
-    if npu_op.sub_op_type not in unary_elementwise_ops:
+    if npu_op.sub_op_type not in UNARY_ELEMWISE_OPS:
         # Binary operation; generate IFM2 registers
         assert npu_op.ifm2 is not None
         has_scalar = npu_op.ifm2_scalar is not None
@@ -1253,9 +896,15 @@
 
 
 def generate_command_stream(
-    emit: CommandStreamEmitter, npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, add_to_debug_db=None
-):
-    """Generates register commands for the given list of NPU operations"""
+    npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None,
+) -> List[int]:
+    """
+    Generates register commands for the given list of NPU operations.
+    Returns Ethos-U instructions, as a list of 32-bit integers.
+    """
+    emit = CommandStreamEmitter()
+    if verbose:
+        print_operations(npu_op_list)
     # Calculate memory accesses for every operation
     memory_accesses = {}
     for npu_op in npu_op_list:
@@ -1285,39 +934,17 @@
             add_to_debug_db(npu_op, emit.offset)
     # Fill in final part of command stream:
     emit.cmd_do_operation(cmd0.NPU_OP_STOP, param=0xFFFF)
-
-
-def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
-    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
-    # Convert high level command stream to list of NpuOperation
-    npu_op_list = []
-    npu_op_to_cmd = dict()  # map from npu op to high level command
-    for cmd in sg.high_level_command_stream:
-        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default:
-            print("Warning: Skipping register command stream generation for", cmd.ps)
-        else:
-            npu_op = convert_command_to_npu_op(cmd, arch)
-            npu_op_list.append(npu_op)
-            npu_op_to_cmd[npu_op] = cmd
-    if verbose:
-        print_operations(npu_op_list)
-    # Generate register commands
-    stream_id = DebugDatabase.add_stream(sg)
-    DebugDatabase.set_stream_offset(sg, 0)  # Default to zero, can only set during file writing
-    emit = CommandStreamEmitter()
-
-    def add_to_debug_db(npu_op: NpuOperation, offset: int):
-        """Adds info to the debug database"""
-        if not is_dma_op(npu_op):
-            cmd = npu_op_to_cmd[npu_op]
-            DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)
-
-    generate_command_stream(emit, npu_op_list, arch, add_to_debug_db)
-    sg.register_command_stream = emit.to_list()
+    res = emit.to_list()
     if verbose:
         emit.print_cmds()
         print("number of commands", len(emit.cmd_stream))
-        print("command stream length in words", len(sg.register_command_stream))
+        print("command stream length in words", len(res))
+    return res
+
+
+# -------------------------------------------------------------------
+# EXTERNAL API
+# -------------------------------------------------------------------
 
 
 def find_block_configs(npu_op: NpuOperation, npu_accelerator: NpuAccelerator) -> List[NpuShape3D]:
@@ -1342,7 +969,5 @@
     :return Ethos-U instructions, as a list of 32-bit integers
     """
     accelerator = Accelerator.from_npu_accelerator(npu_accelerator)
-    emit = CommandStreamEmitter()
     arch = create_default_arch(accelerator)
-    generate_command_stream(emit, npu_op_list, arch)
-    return emit.to_list()
+    return generate_command_stream(npu_op_list, arch, verbose=False)
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
new file mode 100644
index 0000000..ca7e6bc
--- /dev/null
+++ b/ethosu/vela/register_command_stream_util.py
@@ -0,0 +1,543 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Utility functions for code generation
+from typing import List
+from typing import NamedTuple
+from typing import Optional
+
+from . import numeric_util
+from .api import NpuActivationOp
+from .api import NpuAddressRange
+from .api import NpuBlockOperation
+from .api import NpuDmaOperation
+from .api import NpuElementWiseOp
+from .api import NpuFeatureMap
+from .api import NpuKernel
+from .api import NpuLayout
+from .api import NpuOperation
+from .api import NpuOperationType
+from .api import NpuPadding
+from .api import NpuShape3D
+from .architecture_features import ArchitectureFeatures
+from .architecture_features import Block
+from .architecture_features import Rect
+from .operation import Kernel
+from .operation import PointXYZ
+from .range_set import AccessDirection
+from .range_set import MemoryAccessSet
+from .range_set import MemoryRangeSet
+
+# base address slot for memory to memory transfer
+BASE_PTR_INDEX_MEM2MEM = int((1 << 8) | (3 << 0))
+
+
+UNARY_ELEMWISE_OPS = set((NpuElementWiseOp.ABS, NpuElementWiseOp.LRELU, NpuElementWiseOp.CLZ,))
+
+
+def to_npu_kernel(kernel: Kernel) -> NpuKernel:
+    """Converts the given internally used kernel object to NpuKernel (of public API)"""
+    return NpuKernel(
+        kernel.width, kernel.height, kernel.stride.x, kernel.stride.y, kernel.dilation.x, kernel.dilation.y
+    )
+
+
+def to_kernel(kernel: Optional[NpuKernel]) -> Kernel:
+    """Converts the given public API object to Kernel (used internally)"""
+    if kernel is None:
+        return Kernel(1, 1)
+    return Kernel(kernel.width, kernel.height, kernel.stride_x, kernel.stride_y, kernel.dilation_x, kernel.dilation_y)
+
+
+def has_ifm2(npu_op: NpuBlockOperation) -> bool:
+    """Checks if op has non-scalar IFM2"""
+    return npu_op.ifm2 is not None and npu_op.ifm2_scalar is None
+
+
+def is_dma_op(npu_op: NpuOperation) -> bool:
+    """Checks if op is a DMA operation"""
+    return npu_op.op_type == NpuOperationType.Dma
+
+
+def shape3d_size(shape: NpuShape3D) -> int:
+    return shape.width * shape.height * shape.depth
+
+
+def shape3d_to_rect(shape: NpuShape3D) -> Rect:
+    return Rect(0, 0, 0, shape.width - 1, shape.height - 1, shape.depth - 1)
+
+
+# -------------------------------------------------------------------
+# ADDRESSING/STRIDES (helper functions)
+# -------------------------------------------------------------------
+
+
+def ranges_overlap(range1: NpuAddressRange, range2: NpuAddressRange) -> bool:
+    """Checks if the ranges overlap"""
+    return range1.region == range2.region and numeric_util.overlaps(
+        range1.address, range1.address + range1.length, range2.address, range2.address + range2.length
+    )
+
+
+def range_lists_overlap(list1: List[Optional[NpuAddressRange]], list2: List[Optional[NpuAddressRange]]) -> bool:
+    """Checks if there is any address overlap between list1 and list2"""
+    for range1 in list1:
+        if range1 is None:
+            continue
+        for range2 in list2:
+            if range2 is not None and ranges_overlap(range1, range2):
+                return True
+    return False
+
+
+def get_strides(fm: NpuFeatureMap) -> NpuShape3D:
+    """Calculates STRIDE_C/Y/X"""
+    if fm.strides is not None:
+        return fm.strides
+    elem_size = fm.data_type.size_in_bytes()
+    if fm.layout == NpuLayout.NHWC:
+        stride_c = elem_size
+        stride_x = fm.shape.depth * stride_c
+        stride_y = fm.shape.width * stride_x
+    else:
+        stride_x = 16 * elem_size
+        stride_c = stride_x * fm.shape.width
+        stride_y = elem_size * fm.shape.width * numeric_util.round_up(fm.shape.depth, 16)
+    return NpuShape3D(depth=stride_c, height=stride_y, width=stride_x)
+
+
+def get_address(fm: NpuFeatureMap, strides: NpuShape3D, y: int, x: int, c: int) -> int:
+    """Returns address of given coordinate"""
+    t = 0
+    BRICK = 16
+    stride_c = BRICK * fm.data_type.size_in_bytes() if fm.layout == NpuLayout.NHWC else strides.depth
+    stride_x = BRICK * fm.data_type.size_in_bytes() if fm.layout == NpuLayout.NHCWB16 else strides.width
+    if x >= fm.tiles.width_0:
+        x -= fm.tiles.width_0
+        t = 1
+        if y >= fm.tiles.height_1:
+            y -= fm.tiles.height_1
+            t += 2
+    elif y >= fm.tiles.height_0:
+        y -= fm.tiles.height_0
+        t += 2
+    elem_size = fm.data_type.size_in_bytes()
+    return (
+        fm.tiles.addresses[t] + y * strides.height + x * stride_x + (c // BRICK) * stride_c + int(c % BRICK) * elem_size
+    )
+
+
+def get_address_range(
+    fm: NpuFeatureMap, strides: NpuShape3D, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
+) -> NpuAddressRange:
+    """
+    Gets address range for (y0, x0, c0) - (y1, x1, c1) (inclusive, so the second coordinate is within the fm).
+    The begin and end coordinates must be within the same tile.
+    """
+    addr0 = get_address(fm, strides, y0, x0, c0)
+    addr1 = get_address(fm, strides, y1, x1, c1)
+    return NpuAddressRange(region=fm.region, address=addr0, length=addr1 - addr0 + fm.data_type.size_in_bytes())
+
+
+def get_h_ranges(
+    fm: NpuFeatureMap, strides: NpuShape3D, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
+) -> List[NpuAddressRange]:
+    """
+    Gets address ranges for (y0, x0, c0) - (y1, x1, c1) (inclusive, so the second coordinate is within the fm);
+    the begin and end coordinates must be within the same tile.
+    Divides the area in horizontal "stripes" of height 1, and returns the address ranges for these "stripes".
+    """
+    return [get_address_range(fm, strides, y, x0, c0, y, x1, c1) for y in range(y0, y1 + 1)]
+
+
+def get_address_ranges_for_area(fm: NpuFeatureMap, start: PointXYZ, end: PointXYZ) -> List[NpuAddressRange]:
+    """
+    Returns a list of address ranges that covers the area start - end (inclusive).
+    Divides the area in horizontal "stripes" of height 1, and returns the address ranges for these "stripes".
+
+    For example, for the area marked with X (in a feature map with 4 tiles) as input, this function would return
+    6 address ranges: the address ranges for 1-height areas [AAA, BBB, CC, DD, EEE, FF]
+
+        .....|....           .....|....
+     t0 ..XXX|XX.. t1     t0 ..AAA|CC.. t1
+        ..XXX|XX..           ..BBB|DD..
+        -----+----    -->    -----+----
+     t2 ..XXX|XX.. t3     t2 ..EEE|FF.. t3
+        .....|....           .....|....
+    """
+    strides = get_strides(fm)
+    height_0, height_1, width_0 = fm.tiles.height_0, fm.tiles.height_1, fm.tiles.width_0
+    h, w, c = fm.shape
+    y0, x0, c0 = start.y, start.x, start.z
+    y1, x1, c1 = min(end.y, h - 1), min(end.x, w - 1), min(end.z, c - 1)
+    ranges = []
+    if x0 < width_0 and y0 < height_0:
+        # Horizontal ranges for tile 0
+        ranges.extend(get_h_ranges(fm, strides, y0, x0, c0, min(y1, height_0 - 1), min(x1, width_0 - 1), c1))
+    if x1 >= width_0 and y0 < height_1:
+        # Horizontal ranges for tile 1
+        ranges.extend(get_h_ranges(fm, strides, y0, max(x0, width_0), c0, min(y1, height_1 - 1), x1, c1))
+    if x0 < width_0 and y1 >= height_0:
+        # Horizontal ranges for tile 2
+        ranges.extend(get_h_ranges(fm, strides, max(y0, height_0), x0, c0, y1, min(x1, width_0 - 1), c1))
+    if x1 >= width_0 and y1 >= height_1:
+        # Horizontal ranges for tile 3
+        ranges.extend(get_h_ranges(fm, strides, max(y0, height_1), max(x0, width_0), c0, y1, x1, c1))
+    return ranges
+
+
+def get_address_ranges(fm: NpuFeatureMap) -> List[Optional[NpuAddressRange]]:
+    """Returns 4 adddress ranges, one for every tile, None if the tile is not in use"""
+    strides = get_strides(fm)
+    height, width, depth = fm.shape.height, fm.shape.width, fm.shape.depth
+    height_0, height_1, width_0 = fm.tiles.height_0, fm.tiles.height_1, fm.tiles.width_0
+    t0 = get_address_range(fm, strides, 0, 0, 0, min(height, height_0) - 1, min(width, width_0) - 1, depth - 1,)
+    if width > width_0:
+        t1 = get_address_range(fm, strides, 0, width_0, 0, min(height, height_1) - 1, width - 1, depth - 1)
+    else:
+        t1 = None
+    if height > height_0:
+        t2 = get_address_range(fm, strides, height_0, 0, 0, height - 1, min(width, width_0) - 1, depth - 1)
+    else:
+        t2 = None
+    if t1 is not None and t2 is not None:
+        t3 = get_address_range(fm, strides, height_1, width_0, 0, height - 1, width - 1, depth - 1)
+    else:
+        t3 = None
+    return [t0, t1, t2, t3]
+
+
+# -------------------------------------------------------------------
+# DMA_WAIT/KERNEL_WAIT
+# -------------------------------------------------------------------
+
+
+class Watermark(NamedTuple):
+    npu: int
+    dma: int
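+    # One entry per queue: used both for watermark positions into the op list and for
+    # outstanding NPU/DMA operation counts (see get_wait_dependency and generate_cmd_waits)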
+
+
+def memory_range_set(range: NpuAddressRange) -> MemoryRangeSet:
+    return MemoryRangeSet(range.region, range.address, range.address + range.length)
+
+
+def get_dma_memory_accesses(dma_op: NpuDmaOperation) -> MemoryAccessSet:
+    """Returns the address that are read and written by the given DMA operation"""
+    res = MemoryAccessSet()
+    res.add(memory_range_set(dma_op.src), AccessDirection.Read)
+    res.add(memory_range_set(dma_op.dest), AccessDirection.Write)
+    return res
+
+
+def get_op_memory_accesses(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> MemoryAccessSet:
+    """Returns the addresses that are read and written by the given operation"""
+    assert npu_op.ifm is not None and npu_op.ofm is not None
+    # Read addresses
+    read_ranges = get_address_ranges(npu_op.ifm)
+    if has_ifm2(npu_op):
+        assert npu_op.ifm2 is not None
+        read_ranges.extend(get_address_ranges(npu_op.ifm2))
+    read_ranges.extend(npu_op.weights)
+    read_ranges.extend(npu_op.biases)
+    if npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP:
+        address = arch.available_shram_banks(True) * arch.shram_bank_size
+        read_ranges.append(NpuAddressRange(region=BASE_PTR_INDEX_MEM2MEM, address=address, length=2048))
+    # Written addresses
+    write_ranges = get_address_ranges(npu_op.ofm)
+    # Add write access to SHRAM, needed when LUTs can overwrite accumulator banks
+    uses_lut = npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP
+    written_shram_size = arch.available_shram_banks(uses_lut) * arch.shram_bank_size
+    write_ranges.append(NpuAddressRange(region=BASE_PTR_INDEX_MEM2MEM, address=0, length=written_shram_size))
+
+    res = MemoryAccessSet()
+    for read_range in read_ranges:
+        if read_range is not None:
+            res.add(memory_range_set(read_range), AccessDirection.Read)
+    for write_range in write_ranges:
+        if write_range is not None:
+            res.add(memory_range_set(write_range), AccessDirection.Write)
+    return res
+
+
+def get_wait_dependency(
+    arch: ArchitectureFeatures, npu_op_list: List[NpuOperation], memory_accesses, op_index: int, watermark: Watermark
+):
+    """Used to calculate whether DMA wait or kernel wait operations are needed"""
+    npu_op = npu_op_list[op_index]
+    op_access = memory_accesses[npu_op]
+    index = op_index - 1
+
+    # NPU dependency tracking
+    npu_outstanding = -1
+    npu_ops = 0
+    npu_index = watermark.npu
+
+    # DMA dependency tracking
+    dma_outstanding = -1
+    dma_ops = 0
+    dma_index = watermark.dma
+
+    # Seek back in the command stream looking for NPU or DMA dependencies
+    # but only as far as the first dependency or the watermarks (dependencies
+    # before this point have been satisfied already).
+    # The watermark moves to after the latest element we must wait for, not
+    # the command that issues the wait.
+    # NPU->NPU dependency is handled via blockdep.
+    while (index >= npu_index) or (index >= dma_index):
+        prev_op = npu_op_list[index]
+        prev_access = memory_accesses[prev_op]
+
+        # Check NPU consuming DMA output
+        if is_dma_op(prev_op):
+            if index >= dma_index:
+                if not is_dma_op(npu_op):
+                    if (dma_outstanding == -1) and prev_access.conflicts(op_access):
+                        dma_outstanding = dma_ops
+                dma_ops += 1  # Count DMA ops in the pipeline
+                if dma_ops >= arch.max_outstanding_dma:
+                    dma_index = max(index + 1, dma_index)
+        # Check DMA consuming NPU output
+        else:
+            if index >= npu_index:
+                if is_dma_op(npu_op) and npu_outstanding == -1 and prev_access.conflicts(op_access):
+                    npu_outstanding = npu_ops
+                npu_ops += 1  # Count NPU ops in the pipeline
+                if npu_ops >= arch.max_outstanding_kernels:
+                    npu_index = max(index + 1, npu_index)
+
+        index -= 1
+
+    # Update DMA watermark if we didn't see any and the NPU pipeline is full
+    if (dma_ops == 0) and (npu_ops >= arch.max_outstanding_kernels):
+        dma_index = op_index
+
+    # Bring the search watermark forwards as we complete for those dependencies
+    watermark = Watermark(npu_index, dma_index)
+    outstanding = Watermark(npu_outstanding, dma_outstanding)
+
+    return watermark, outstanding
+
+
+# -------------------------------------------------------------------
+# BLOCKDEP
+# -------------------------------------------------------------------
+
+
+def get_ifm_ofm_block_depth(arch: ArchitectureFeatures, npu_op: NpuBlockOperation) -> int:
+    # Note: NOT equivalent to the normal ifm block depth calculation since
+    # it takes into account 'depthless' block operations by returning full
+    # depth
+    if npu_op.op_type == NpuOperationType.Conv2D:
+        res = arch.calc_ifm_block_depth(npu_op.ifm.shape.depth, npu_op.ifm.data_type.size_in_bits())
+        return res
+    return npu_op.ofm.shape.depth
+
+
+def coords_intersect(start_a: PointXYZ, end_a: PointXYZ, start_b: PointXYZ, end_b: PointXYZ) -> bool:
+    """Checks if the two areas overlap"""
+    start_x = max(start_a.x, start_b.x)
+    end_x = min(end_a.x, end_b.x)
+    start_y = max(start_a.y, start_b.y)
+    end_y = min(end_a.y, end_b.y)
+    start_z = max(start_a.z, start_b.z)
+    end_z = min(end_a.z, end_b.z)
+    return ((end_x - start_x) > 0) and ((end_y - start_y) > 0) and ((end_z - start_z) > 0)
+
+
+def intersects(
+    ifm: NpuFeatureMap,
+    ifm_start_coord: PointXYZ,
+    ifm_end_coord: PointXYZ,
+    prev_ofm: NpuFeatureMap,
+    ofm_start_coord: PointXYZ,
+    ofm_end_coord: PointXYZ,
+) -> bool:
+    """Checks if the given IFM area overlaps with the given OFM area"""
+    if ifm.shape == prev_ofm.shape and ifm.tiles == prev_ofm.tiles:
+        # Common case: prev_op.ofm == op.ifm; in this case it suffices to check
+        # if the xyz coordinates overlap, which is quick and easy
+        res = coords_intersect(ifm_start_coord, ifm_end_coord, ofm_start_coord, ofm_end_coord)
+    else:
+        # The OFM produces a part of the IFM (e.g. a stripe), or the IFM consumes part of the OFM.
+        # In this case, address comparison between the two areas is needed
+        ifm_ranges = get_address_ranges_for_area(ifm, ifm_start_coord, ifm_end_coord)
+        prev_ofm_ranges = get_address_ranges_for_area(prev_ofm, ofm_start_coord, ofm_end_coord)
+        res = range_lists_overlap(ifm_ranges, prev_ofm_ranges)
+    return res
+
+
+# Block job dependency:
+# Does the VOLUME of IFMs for block job B(0) overlap with the VOLUME of OFMs for block jobs A(8,9,10)?
+#
+#  A                    | B
+# ----------------------+------------------
+# .... 3,4,5,6,7,8,9,10 | 0,1,2,3,4,5,6,8 10 < JOB NUMBER
+#               |<------->| dependency offset
+#
+
+
+def get_offset_block_coords(area: Rect, block: Block, offset: int) -> Optional[PointXYZ]:
+    """
+    Get the coordinates of a block offset from either the end (negative)
+    or the start (zero or positive) of the given 3D area
+    """
+    size = area.size()
+    # Dimensions of the region, in blocks
+    width_blocks = numeric_util.round_up_divide(size.width, block.width)
+    height_blocks = numeric_util.round_up_divide(size.height, block.height)
+    depth_blocks = numeric_util.round_up_divide(size.depth, block.depth)
+    total_blocks = width_blocks * height_blocks * depth_blocks
+    if offset < 0:
+        index = total_blocks + offset
+    else:
+        index = offset
+
+    if index >= total_blocks:
+        return None
+
+    # Coordinates of the indexed block
+    coord_z = block.depth * (index % depth_blocks)
+    coord_y = block.height * (index // (depth_blocks * width_blocks))
+    coord_x = block.width * ((index // depth_blocks) % width_blocks)
+
+    return PointXYZ(x=coord_x + area.x, y=coord_y + area.y, z=coord_z + area.z)
+
+
+def get_first_job_input_volume(
+    arch: ArchitectureFeatures,
+    ifm: Rect,
+    ofm: Rect,
+    ifm_block_depth,
+    ofm_block: Block,
+    kernel: Kernel,
+    padding: NpuPadding,
+    block_offset: int,
+):
+    # Get ifm block size (jobs are invisibly decomposed into subkernels)
+    ifm_block = arch.get_ifm_block_size(ifm_block_depth, ofm_block, kernel, arch.ofm_block_max)
+    ifm_depth_blocks = numeric_util.round_up_divide(ifm.size().depth, ifm_block_depth)
+
+    # Which OFM block are we calculating
+    ofm_coord = get_offset_block_coords(ofm, ofm_block, block_offset // ifm_depth_blocks)
+    if ofm_coord is None:
+        return None
+
+    # Coordinate of the source IFM block
+    ifm_coord_x = max(0, ofm_coord[0] * kernel.stride.x - padding.left)
+    ifm_coord_y = max(0, ofm_coord[1] * kernel.stride.y - padding.top)
+    ifm_coord_z = ifm.z + (block_offset % ifm_depth_blocks) * ifm_block.depth
+
+    # IFM block that will be sampled for the FIRST+block_offset job in the next operator's OFM
+    start_coord = PointXYZ(x=ifm_coord_x, y=ifm_coord_y, z=ifm_coord_z)
+    end_coord = PointXYZ(
+        x=start_coord[0] + ifm_block.width, y=start_coord[1] + ifm_block.height, z=start_coord[2] + ifm_block.depth,
+    )
+    return (start_coord, end_coord, 1)  # start, end, total jobs
+
+
+def get_prev_job_output_volume(ofm: Rect, ofm_block: Block, block_offset: int):
+    assert block_offset >= 0
+
+    # Get OFM block's volume coordinates
+    start_coord = get_offset_block_coords(ofm, ofm_block, -1 - block_offset)
+    if start_coord is None:
+        return None
+    end_coord = PointXYZ(
+        x=start_coord.x + ofm_block.width, y=start_coord.y + ofm_block.height, z=start_coord.z + ofm_block.depth,
+    )
+    return (start_coord, end_coord, 1)  # start, end, total jobs for this OFM block
+
+
+def calc_blockdep(arch: ArchitectureFeatures, prev_op: Optional[NpuBlockOperation], npu_op: NpuBlockOperation,) -> int:
+    """Calculates the value of the BLOCKDEP register"""
+    if prev_op is None:
+        return 0
+    assert npu_op.ifm is not None
+    assert prev_op.ofm is not None
+    # Check if IFM or IFM2 overlaps with prev op's OFM
+    prev_ofm_ranges = get_address_ranges(prev_op.ofm)
+    ifm_ranges = get_address_ranges(npu_op.ifm)
+    ifm_overlaps = range_lists_overlap(prev_ofm_ranges, ifm_ranges)
+    if has_ifm2(npu_op):
+        assert npu_op.ifm2 is not None
+        ifm2_ranges = get_address_ranges(npu_op.ifm2)
+        ifm2_overlaps = range_lists_overlap(prev_ofm_ranges, ifm2_ranges)
+    else:
+        ifm2_overlaps = False
+    if ifm_overlaps and ifm2_overlaps:
+        # Both IFM and IFM2 overlap (should be rare)
+        return 0
+    if not ifm_overlaps and not ifm2_overlaps:
+        # No overlap between prev OFM and IFM/IFM2
+        return ArchitectureFeatures.MAX_BLOCKDEP
+    if ifm2_overlaps and shape3d_size(npu_op.ifm2.shape) < shape3d_size(npu_op.ifm.shape):
+        # Prev OFM produces IFM2 which is broadcasted (this should be rare)
+        return 0
+    # Prev OFM overlaps with IFM or IFM2; calculate the blockdep
+    prev_block_config = prev_op.block_config
+    block_config = npu_op.block_config
+    overlapping_fm = npu_op.ifm if ifm_overlaps else npu_op.ifm2
+    assert overlapping_fm is not None
+
+    cur_ifm_block_depth = get_ifm_ofm_block_depth(arch, npu_op)
+    cur_ofm_block = Block(block_config.width, block_config.height, block_config.depth)
+    cur_ofm_rect = shape3d_to_rect(npu_op.ofm.shape)
+    cur_ifm_rect = shape3d_to_rect(npu_op.ifm.shape)
+    padding = NpuPadding(0, 0, 0, 0) if npu_op.padding is None else npu_op.padding
+    blockdep = ArchitectureFeatures.MAX_BLOCKDEP
+    kernel = to_kernel(npu_op.kernel)
+
+    prev_ofm_block = Block(prev_block_config.width, prev_block_config.height, prev_block_config.depth)
+    prev_ofm_rect = shape3d_to_rect(prev_op.ofm.shape)
+    # Iterate over the next BLOCKDEP inputs, checking to see if a sliding window
+    # of IFM area overlaps with any previous OFM block generation.
+    elapsed_jobs = 0
+    for forward_offset in range(ArchitectureFeatures.MAX_BLOCKDEP):
+        # This is the IFM block we want to sample from
+        in_area = get_first_job_input_volume(
+            arch, cur_ifm_rect, cur_ofm_rect, cur_ifm_block_depth, cur_ofm_block, kernel, padding, forward_offset
+        )
+        if in_area is None:
+            break
+
+        # Try several previous-OFM blocks in the past (they still might comprise multiple IFM jobs)
+        outstanding_jobs = 0
+        for block_offset in range(ArchitectureFeatures.MAX_BLOCKDEP):
+            # This is the OFM block being generated by the previous op
+            out_area = get_prev_job_output_volume(prev_ofm_rect, prev_ofm_block, block_offset)
+            if out_area is None:
+                break
+
+            # Block dependency is the max number of allowed outstanding jobs
+            # in the pipeline. Selected by determining how many jobs occur
+            # in between two operators' overlapping OFM->IFM block volumes
+            if intersects(overlapping_fm, in_area[0], in_area[1], prev_op.ofm, out_area[0], out_area[1]):
+                break
+            # Early exit if no intersections and we've seen enough jobs in the pipeline
+            elif outstanding_jobs > ArchitectureFeatures.MAX_BLOCKDEP:
+                break
+
+            # This OFM had this many jobs (accumulate over multiple OFM blocks)
+            outstanding_jobs += out_area[2]
+
+        blockdep = min(blockdep, elapsed_jobs + outstanding_jobs)
+        elapsed_jobs += in_area[2]
+        # Early exit if no intersections and we've seen enough jobs in the pipeline
+        if elapsed_jobs > ArchitectureFeatures.MAX_BLOCKDEP:
+            break
+
+    return blockdep
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index ee55962..21b048b 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -28,10 +28,10 @@
 from .architecture_features import SHRAMElements
 from .errors import VelaError
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
-from .high_level_command_to_npu_op import to_kernel
 from .operation import Kernel
 from .operation import NpuBlockType
 from .range_set import MemoryRangeSet
+from .register_command_stream_util import to_kernel
 from .tensor import MemArea
 
 
diff --git a/ethosu/vela/test/extapi/test_extapi_generate_commands.py b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
index 812991a..b605dfc 100644
--- a/ethosu/vela/test/extapi/test_extapi_generate_commands.py
+++ b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
@@ -41,7 +41,7 @@
 from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd0
 from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd1
 from ethosu.vela.register_command_stream_generator import CmdMode
-from ethosu.vela.register_command_stream_generator import get_address_ranges
+from ethosu.vela.register_command_stream_util import get_address_ranges
 
 
 def check_cmd0(cmd_stream, cmd, param):
diff --git a/ethosu/vela/test/test_register_command_generator.py b/ethosu/vela/test/test_register_command_stream_util.py
similarity index 98%
rename from ethosu/vela/test/test_register_command_generator.py
rename to ethosu/vela/test/test_register_command_stream_util.py
index 2760c86..985523f 100644
--- a/ethosu/vela/test/test_register_command_generator.py
+++ b/ethosu/vela/test/test_register_command_stream_util.py
@@ -32,8 +32,8 @@
 from ethosu.vela.architecture_features import Accelerator
 from ethosu.vela.architecture_features import create_default_arch
 from ethosu.vela.register_command_stream_generator import calc_blockdep
-from ethosu.vela.register_command_stream_generator import get_address_ranges
 from ethosu.vela.register_command_stream_generator import get_strides
+from ethosu.vela.register_command_stream_util import get_address_ranges
 from ethosu.vela.test.extapi.test_extapi_generate_commands import create_feature_map