MLBEDSW-3562: Improve blockdep calculation

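Blockdep calculation can now handle IFM/OFM of different sizes.
Previously, a blockdep of 0 was used whenever the IFM shape differed
from the previous operation's OFM shape. Now, when the overlapping
IFM (or IFM2) differs in shape or tiling from the previous OFM, block
overlap is determined by comparing address ranges of the two areas.

Also fixes the start row of tile 3 in get_address_ranges (height_1
instead of height_0).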

Change-Id: I898a3c1c3a6778916802f3dbfa658328e5093096
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 015a8c4..741b09c 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -23,6 +23,7 @@
 from enum import IntEnum
 from typing import List
 from typing import Optional
+from typing import Tuple
 
 import numpy as np
 
@@ -745,6 +746,17 @@
     )
 
 
+def range_lists_overlap(list1: List[Optional[NpuAddressRange]], list2: List[Optional[NpuAddressRange]]) -> bool:
+    """Checks if any address range in list1 overlaps any address range in list2; None entries are ignored"""
+    for range1 in list1:
+        if range1 is None:
+            continue
+        for range2 in list2:
+            if range2 is not None and ranges_overlap(range1, range2):
+                return True
+    return False
+
+
 def get_strides(fm: NpuFeatureMap) -> NpuShape3D:
     """Calculates STRIDE_C/Y/X"""
     if fm.strides is not None:
@@ -785,12 +797,63 @@
 def get_address_range(
     fm: NpuFeatureMap, strides: NpuShape3D, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
 ) -> NpuAddressRange:
-    """Gets address range for (y0, x0, c0) - (y1, x1, c1)"""
+    """
+    Gets address range for (y0, x0, c0) - (y1, x1, c1) (inclusive, so the second coordinate is within the fm).
+    The begin and end coordinates must be within the same tile.
+    """
     addr0 = get_address(fm, strides, y0, x0, c0)
     addr1 = get_address(fm, strides, y1, x1, c1)
     return NpuAddressRange(region=fm.region, address=addr0, length=addr1 - addr0 + fm.data_type.size_in_bytes())
 
 
+def get_h_ranges(
+    fm: NpuFeatureMap, strides: NpuShape3D, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
+) -> List[NpuAddressRange]:
+    """
+    Gets address ranges for (y0, x0, c0) - (y1, x1, c1) (inclusive, so the second coordinate is within the fm);
+    the begin and end coordinates must be within the same tile.
+    Divides the area into horizontal "stripes" of height 1 and returns one address range per stripe.
+    """
+    return [get_address_range(fm, strides, y, x0, c0, y, x1, c1) for y in range(y0, y1 + 1)]
+
+
+def get_address_ranges_for_area(
+    fm: NpuFeatureMap, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
+) -> List[NpuAddressRange]:
+    """
+    Returns a list of address ranges that cover the area (y0, x0, c0) - (y1, x1, c1) (inclusive).
+    Divides the area into horizontal "stripes" of height 1 within each tile, and returns the address ranges for these stripes.
+
+    For example, given the area marked with X below (in a feature map with 4 tiles), this function returns
+    6 address ranges, one per height-1 stripe, ordered tile by tile: [AAA, BBB, CC, DD, EEE, FF]
+
+        .....|....           .....|....
+     t0 ..XXX|XX.. t1     t0 ..AAA|CC.. t1
+        ..XXX|XX..           ..BBB|DD..
+        -----+----    -->    -----+----
+     t2 ..XXX|XX.. t3     t2 ..EEE|FF.. t3
+        .....|....           .....|....
+    """
+    strides = get_strides(fm)
+    height_0, height_1, width_0 = fm.tiles.height_0, fm.tiles.height_1, fm.tiles.width_0
+    h, w, c = fm.shape
+    y2, x2, c2 = min(y1, h - 1), min(x1, w - 1), min(c1, c - 1)
+    ranges = []
+    if x0 < width_0 and y0 < height_0:
+        # Horizontal ranges for tile 0
+        ranges.extend(get_h_ranges(fm, strides, y0, x0, c0, min(y2, height_0 - 1), min(x2, width_0 - 1), c2))
+    if x2 >= width_0 and y0 < height_1:
+        # Horizontal ranges for tile 1
+        ranges.extend(get_h_ranges(fm, strides, y0, max(x0, width_0), c0, min(y2, height_1 - 1), x2, c2))
+    if x0 < width_0 and y2 >= height_0:
+        # Horizontal ranges for tile 2
+        ranges.extend(get_h_ranges(fm, strides, max(y0, height_0), x0, c0, y2, min(x2, width_0 - 1), c2))
+    if x2 >= width_0 and y2 >= height_1:
+        # Horizontal ranges for tile 3
+        ranges.extend(get_h_ranges(fm, strides, max(y0, height_1), max(x0, width_0), c0, y2, x2, c2))
+    return ranges
+
+
 def get_address_ranges(fm: NpuFeatureMap) -> List[Optional[NpuAddressRange]]:
     """Returns 4 adddress ranges, one for every tile, None if the tile is not in use"""
     strides = get_strides(fm)
@@ -806,7 +869,7 @@
     else:
         t2 = None
     if t1 is not None and t2 is not None:
-        t3 = get_address_range(fm, strides, height_0, width_0, 0, height - 1, width - 1, depth - 1)
+        t3 = get_address_range(fm, strides, height_1, width_0, 0, height - 1, width - 1, depth - 1)
     else:
         t3 = None
     return [t0, t1, t2, t3]
@@ -934,22 +997,8 @@
 # -------------------------------------------------------------------
 
 
-def is_dependent_on_prev_op(prev_op: NpuBlockOperation, npu_op: NpuBlockOperation) -> bool:
-    """Checks if npu_op's input is dependent on prev_op's output"""
-    assert npu_op.ifm is not None
-    assert prev_op.ofm is not None
-    curr_input_ranges = get_address_ranges(npu_op.ifm)
-
-    if has_ifm2(npu_op):
-        assert npu_op.ifm2 is not None
-        curr_input_ranges.extend(get_address_ranges(npu_op.ifm2))
-    for prev_range in get_address_ranges(prev_op.ofm):
-        if prev_range is None:
-            continue
-        for curr_range in curr_input_ranges:
-            if curr_range is not None and ranges_overlap(prev_range, curr_range):
-                return True
-    return False
+def shape3d_size(shape: NpuShape3D) -> int:
+    return shape.width * shape.height * shape.depth
 
 
 def shape3d_to_rect(shape: NpuShape3D) -> Rect:
@@ -970,35 +1019,66 @@
     """Calculates the value of the BLOCKDEP register"""
     if prev_op is None:
         return 0
-    if not is_dependent_on_prev_op(prev_op, npu_op):
+    assert npu_op.ifm is not None
+    assert prev_op.ofm is not None
+    # Check if IFM or IFM2 overlaps with prev op's OFM
+    prev_ofm_ranges = get_address_ranges(prev_op.ofm)
+    ifm_ranges = get_address_ranges(npu_op.ifm)
+    ifm_overlaps = range_lists_overlap(prev_ofm_ranges, ifm_ranges)
+    if has_ifm2(npu_op):
+        assert npu_op.ifm2 is not None
+        ifm2_ranges = get_address_ranges(npu_op.ifm2)
+        ifm2_overlaps = range_lists_overlap(prev_ofm_ranges, ifm2_ranges)
+    else:
+        ifm2_overlaps = False
+    if ifm_overlaps and ifm2_overlaps:
+        # Both IFM and IFM2 overlap (should be rare)
+        return 0
+    if not ifm_overlaps and not ifm2_overlaps:
+        # No overlap between prev OFM and IFM/IFM2
         return ArchitectureFeatures.MAX_BLOCKDEP
-    if prev_op.ofm.shape != npu_op.ifm.shape:
+    if ifm2_overlaps and shape3d_size(npu_op.ifm2.shape) < shape3d_size(npu_op.ifm.shape):
+        # Prev OFM produces IFM2 which is broadcasted (this should be rare)
         return 0
     prev_block_config = prev_op.block_config
     block_config = npu_op.block_config
-    prev_ifm_block_depth = get_ifm_ofm_block_depth(arch, prev_op)
+    overlapping_fm = npu_op.ifm if ifm_overlaps else npu_op.ifm2
+    assert overlapping_fm is not None
+
+    def intersects(ifm_start_coord: Tuple, ifm_end_coord: Tuple, ofm_start_coord: Tuple, ofm_end_coord: Tuple) -> bool:
+        """Checks if the given IFM area overlaps with the given OFM area"""
+        if overlapping_fm.shape == prev_op.ofm.shape and overlapping_fm.tiles == prev_op.ofm.tiles:
+            # Common case: overlapping_fm has the same shape and tiling as prev_op.ofm; it then suffices
+            # to check whether the xyz coordinates overlap, which is quick and easy
+            return ArchitectureFeatures.intersects(ifm_start_coord, ifm_end_coord, ofm_start_coord, ofm_end_coord)
+        # The OFM produces a part of the IFM (e.g. a stripe), or the IFM consumes part of the OFM.
+        # In this case address comparison is needed between the two areas
+        x0, y0, c0 = ifm_start_coord
+        x1, y1, c1 = ifm_end_coord
+        ifm_ranges = get_address_ranges_for_area(overlapping_fm, y0, x0, c0, y1, x1, c1)
+        x0, y0, c0 = ofm_start_coord
+        x1, y1, c1 = ofm_end_coord
+        prev_ofm_ranges = get_address_ranges_for_area(prev_op.ofm, y0, x0, c0, y1, x1, c1)
+        return range_lists_overlap(ifm_ranges, prev_ofm_ranges)
+
     prev_ofm_block = Block(prev_block_config.width, prev_block_config.height, prev_block_config.depth)
     prev_ofm_rect = shape3d_to_rect(prev_op.ofm.shape)
-    prev_ifm_rect = shape3d_to_rect(prev_op.ifm.shape)
     cur_ifm_block_depth = get_ifm_ofm_block_depth(arch, npu_op)
     cur_ofm_block = Block(block_config.width, block_config.height, block_config.depth)
     cur_ofm_rect = shape3d_to_rect(npu_op.ofm.shape)
     cur_ifm_rect = shape3d_to_rect(npu_op.ifm.shape)
     cur_padLT = (0, 0) if npu_op.padding is None else (npu_op.padding.left, npu_op.padding.top)
-    blockdep = arch.calc_block_dep(
-        prev_ifm_rect,
+    return arch.calc_block_dep(
         prev_ofm_rect,
-        prev_ifm_block_depth,
         prev_ofm_block,
-        to_kernel(prev_op.kernel),
         cur_ifm_rect,
         cur_ofm_rect,
         cur_ifm_block_depth,
         cur_ofm_block,
         to_kernel(npu_op.kernel),
         cur_padLT,
+        intersects=intersects,
     )
-    return blockdep
 
 
 # -------------------------------------------------------------------