MLBEDSW-4807 Elementwise IFM/OFM overlap

Reinstated support for allowing the IFM and OFM tensors of elementwise
operations to overlap, so that the OFM can reuse the memory of one of
its IFMs.
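
For reference, a minimal self-contained sketch of the intended effect
(ToyLiveRangeGraph below is an illustrative stand-in, not the
LiveRangeGraph in live_range.py): once the IFM and OFM live ranges are
fused they share a single allocation, so the elementwise output may
overwrite its input.

    # Illustration only: a toy model of fusing two live ranges.
    class ToyLiveRangeGraph:
        def __init__(self):
            # tensor name -> set of names that share one allocation
            self.ranges = {}

        def get_or_create_range(self, name):
            return self.ranges.setdefault(name, {name})

        def fuse_ranges(self, in_name, out_name):
            # Merge the two ranges; every member now maps to the same
            # set, so the allocator gives them the same address.
            a = self.get_or_create_range(in_name)
            b = self.get_or_create_range(out_name)
            merged = a | b
            for name in merged:
                self.ranges[name] = merged

    lr_graph = ToyLiveRangeGraph()
    lr_graph.get_or_create_range("ifm")
    lr_graph.get_or_create_range("ofm")
    # Same format/dtype/shape and a single consumer: safe to overlap
    lr_graph.fuse_ranges("ifm", "ofm")
    assert lr_graph.ranges["ifm"] is lr_graph.ranges["ofm"]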

Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com>
Change-Id: Ide6db7781f3ca7a36c8ff9e3efdc7943a7bf6d7f
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index b687a9e..2795b66 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -20,7 +20,6 @@
 
 import numpy as np
 
-from .nn_graph import PassPlacement
 from .operation import Op
 from .tensor import MemArea
 from .tensor import MemType
@@ -167,98 +166,40 @@
     return False
 
 
-# Tries merging of ifm/ofm live ranges for memory only ops and elementwise ops
-def merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set):
-    for ps in sg.passes:
-        if ps.placement == PassPlacement.MemoryOnly:
-            # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
-            input_tensor = ps.inputs[0]
-            output_tensor = ps.outputs[0]
-            if not tensor_should_be_ignored(lr_graph, input_tensor, target_mem_area, target_mem_type_set) and not (
-                tensor_should_be_ignored(lr_graph, output_tensor, target_mem_area, target_mem_type_set)
-            ):
-                lr_graph.fuse_ranges(input_tensor, output_tensor)
-        elif ps.is_element_wise:
-            merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set)
+def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set):
+    # Tries to merge ifm/ofm live ranges of elementwise op
+    if sched_op.op_type.is_elementwise_op():
+        elem_op = sched_op.parent_op
+        if not tensor_should_be_ignored(lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set):
+            # Check if overwriting the inputs can be allowed
+            if elem_op.type not in (Op.SHL, Op.SHR):
+                inps = []
+                if (
+                    elem_op.ifm is not None
+                    and elem_op.ifm.shape != []
+                    and elem_op.ifm.mem_area == target_mem_area
+                    and elem_op.ifm.mem_type in target_mem_type_set
+                ):
+                    inps.append(elem_op.ifm)
+                if (
+                    elem_op.ifm2 is not None
+                    and elem_op.ifm2.shape != []
+                    and elem_op.ifm2.mem_area == target_mem_area
+                    and elem_op.ifm2.mem_type in target_mem_type_set
+                ):
+                    inps.append(elem_op.ifm2)
 
-
-# Tries to merge ifm/ofm live of elementwise op
-def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set):
-    elem_op = None
-    for op in ps.ops:
-        if op.type.is_elementwise_op():
-            assert elem_op is None
-            elem_op = op
-
-    if elem_op is not None and not tensor_should_be_ignored(
-        lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set
-    ):
-        # Check if overwriting the inputs can be allowed
-        if elem_op.type not in (Op.SHL, Op.SHR):
-            inps = []
-            if (
-                elem_op.ifm is not None
-                and elem_op.ifm.shape != []
-                and elem_op.ifm.mem_area == target_mem_area
-                and elem_op.ifm.mem_type in target_mem_type_set
-            ):
-                inps.append(elem_op.ifm)
-            if (
-                elem_op.ifm2 is not None
-                and elem_op.ifm2.shape != []
-                and elem_op.ifm2.mem_area == target_mem_area
-                and elem_op.ifm.mem_type in target_mem_type_set
-            ):
-                inps.append(elem_op.ifm2)
-
-            if len(inps) > 0:
-                for i, inp in enumerate(inps):
-                    # check input format, dtype, broadcasting or if there are more input consumers
-                    if (
-                        inp.format == elem_op.ofm.format
-                        and inp.dtype == elem_op.ofm.dtype
-                        and elem_op.ifm_shapes[i] == elem_op.ofm_shapes[0]
-                        and (len(inp.consumer_list) == 1 and len(inp.ops) == 1)
-                    ):
-                        lr_graph.fuse_ranges(inp, elem_op.ofm)
-                        break
-
-
-def extract_live_ranges_from_passes(
-    sg, target_mem_area, target_mem_type_set=None, ignore_subgraph_input_output_tensors=False,
-):
-    lr_graph = LiveRangeGraph()
-
-    if ignore_subgraph_input_output_tensors:
-        lr_graph.ignore_tensors.update(sg.input_tensors)
-        lr_graph.ignore_tensors.update(sg.output_tensors)
-
-    if target_mem_type_set is None:
-        target_mem_type_set = set((MemType.Scratch, MemType.Scratch_fast))
-
-    # Try to merge live ranges of operations in the NPU subgraphs
-    if sg.placement == PassPlacement.Npu:
-        merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)
-
-    for idx, ps in enumerate(sg.passes):
-        ps.time = 2 * idx
-
-        time_for_pass = ps.time
-
-        for tens in ps.inputs + ps.intermediates + ps.outputs:
-            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
-                continue
-            rng = lr_graph.get_or_create_range(tens)
-            rng.mark_usage(time_for_pass)
-
-    end_time = len(sg.passes) * 2
-    for tens in sg.output_tensors:
-        if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
-            continue
-        rng = lr_graph.get_or_create_range(tens)
-        rng.mark_usage(end_time)
-
-    return lr_graph
+                if len(inps) > 0:
+                    for i, inp in enumerate(inps):
+                        # Fuse only if format, dtype and shape match and the input has one producer and consumer
+                        if (
+                            inp.format == elem_op.ofm.format
+                            and inp.dtype == elem_op.ofm.dtype
+                            and elem_op.ifm_shapes[i] == elem_op.ofm_shapes[0]
+                            and (len(inp.consumer_list) == 1 and len(inp.ops) == 1)
+                        ):
+                            lr_graph.fuse_ranges(inp, elem_op.ofm)
+                            break
 
 
 def extract_live_ranges_from_cascaded_passes(
@@ -280,10 +221,6 @@
         lr_graph.ignore_tensors.update(sg.input_tensors)
         lr_graph.ignore_tensors.update(sg.output_tensors)
 
-    # Try to merge live ranges of operations in the NPU subgraphs
-    if sg.placement == PassPlacement.Npu:
-        merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)
-
     for cps in sg.cascaded_passes:
         cps.time = lr_graph.current_time
 
@@ -347,7 +284,7 @@
             rng = lr_graph.get_or_create_range(tens)
             rng.mark_usage(sg_time)
 
-    for sched_op, op_info in sg.schedule.cost_map.items():
+    for op_info in sg.schedule.cost_map.values():
         for tensor in [op_info.npu_weights_tensor, op_info.npu_scales_tensor]:
             if tensor and not (tensor_should_be_ignored(lr_graph, tensor, target_mem_area, target_mem_type_set)):
                 rng = lr_graph.get_or_create_range(tensor)
@@ -360,6 +297,8 @@
 def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
     time_for_cascade = {}
     for sched_op in sg.sched_ops:
+        merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set)
+
         op_info = sg.schedule.cost_map[sched_op]
         cascade = op_info.cascade
         cascade_info = sg.schedule.cascades.get(cascade, None)