MLBEDSW-3212 Enable overlap of elementwise input/output

Enable overlap of elementwise input/output

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I6e6f11953319c843c8203bf038f96778df194332
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 23026c7..b884035 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -98,18 +98,6 @@
         self.alignment = max(self.alignment, alignment)
 
 
-def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area):
-    for ps in sg.passes:
-        if ps.placement == PassPlacement.MemoryOnly:
-            # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
-            input_tensor = ps.inputs[0]
-            output_tensor = ps.outputs[0]
-            if not tensor_should_be_ignored(input_tensor, target_mem_area) and not tensor_should_be_ignored(
-                output_tensor, target_mem_area
-            ):
-                lr_graph.fuse_ranges(input_tensor, output_tensor)
-
-
 class LiveRangeGraph:
     def __init__(self):
         self.ranges = {}  # tens -> range
@@ -138,10 +126,79 @@
         return live_range
 
 
+def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+    if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
+        return True
+    if tens in lr_graph.ignore_tensors:
+        return True
+    if tens.name.endswith("reshape_shape_npu"):
+        # Reshape tensor, no need to allocate
+        lr_graph.ignore_tensors.add(tens)
+        return True
+    return False
+
+
+# Tries merging of ifm/ofm live ranges for memory only ops and elementwise ops
+def merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set):
+    for ps in sg.passes:
+        if ps.placement == PassPlacement.MemoryOnly:
+            # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
+            input_tensor = ps.inputs[0]
+            output_tensor = ps.outputs[0]
+            if not tensor_should_be_ignored(lr_graph, input_tensor, target_mem_area, target_mem_type_set) and not (
+                tensor_should_be_ignored(lr_graph, output_tensor, target_mem_area, target_mem_type_set)
+            ):
+                lr_graph.fuse_ranges(input_tensor, output_tensor)
+        elif ps.is_element_wise:
+            merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set)
+
+
+# Tries to merge the ifm/ofm live ranges of an elementwise op
+def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set):
+    elem_op = None
+    for op in ps.ops:
+        if op.type.is_elementwise_op():
+            assert elem_op is None
+            elem_op = op
+
+    if elem_op is not None and not tensor_should_be_ignored(
+        lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set
+    ):
+        # Check if overwriting the inputs can be allowed
+        if elem_op.type not in (Op.SHL, Op.SHR):
+            inps = []
+            if (
+                elem_op.ifm is not None
+                and elem_op.ifm.shape != []
+                and elem_op.ifm.mem_area == target_mem_area
+                and elem_op.ifm.mem_type in target_mem_type_set
+            ):
+                inps.append(elem_op.ifm)
+            if (
+                elem_op.ifm2 is not None
+                and elem_op.ifm2.shape != []
+                and elem_op.ifm2.mem_area == target_mem_area
+                and elem_op.ifm2.mem_type in target_mem_type_set
+            ):
+                inps.append(elem_op.ifm2)
+
+            if inps:
+                for inp in inps:
+                    # check input format, dtype, broadcasting or if there are more input consumers
+                    if (
+                        inp.format == elem_op.ofm.format
+                        and inp.dtype == elem_op.ofm.dtype
+                        and inp.shape == elem_op.ofm.shape
+                        and (len(inp.consumer_list) == 1 and len(inp.ops) == 1)
+                    ):
+                        lr_graph.fuse_ranges(inp, elem_op.ofm)
+                        break
+
+
 def extract_live_ranges_from_passes(
     sg,
     target_mem_area,
-    mark_output_tensors_overlapping_with_input_tensors=False,
+    target_mem_type=set((MemType.Scratch, MemType.Scratch_fast)),
     ignore_subgraph_input_output_tensors=False,
 ):
     lr_graph = LiveRangeGraph()
@@ -150,50 +207,24 @@
         lr_graph.ignore_tensors.update(sg.input_tensors)
         lr_graph.ignore_tensors.update(sg.output_tensors)
 
-    def tensor_should_be_ignored(tens, target_mem_area):
-        if tens.mem_area != target_mem_area:
-            return True
-        if tens in lr_graph.ignore_tensors:
-            return True
-        if tens.name.endswith("reshape_shape_npu"):
-            # Reshape tensor, no need to allocate
-            lr_graph.ignore_tensors.add(tens)
-            return True
-        return False
-
-    # Merge only memory operations in the NPU subgraphs
+    # Try to merge live ranges of operations in the NPU subgraphs
     if sg.placement == PassPlacement.Npu:
-        merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area)
+        merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type)
 
     for idx, ps in enumerate(sg.passes):
         ps.time = 2 * idx
 
         time_for_pass = ps.time
 
-        for tens in ps.inputs:
-            if tensor_should_be_ignored(tens, target_mem_area):
+        for tens in ps.inputs + ps.intermediates + ps.outputs:
+            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type):
                 continue
             rng = lr_graph.get_or_create_range(tens)
             rng.mark_usage(time_for_pass)
 
-        for tens in ps.intermediates:
-            if tensor_should_be_ignored(tens, target_mem_area):
-                continue
-            rng = lr_graph.get_or_create_range(tens)
-            rng.mark_usage(time_for_pass)
-
-        for tens in ps.outputs:
-            if tensor_should_be_ignored(tens, target_mem_area):
-                continue
-            rng = lr_graph.get_or_create_range(tens)
-            output_time = time_for_pass
-            if not mark_output_tensors_overlapping_with_input_tensors and ps.is_element_wise:
-                output_time += 1
-            rng.mark_usage(output_time)
-
     end_time = len(sg.passes) * 2
     for tens in sg.output_tensors:
-        if tensor_should_be_ignored(tens, target_mem_area):
+        if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type):
             continue
         rng = lr_graph.get_or_create_range(tens)
         rng.mark_usage(end_time)
@@ -205,7 +236,6 @@
     sg,
     target_mem_area,
     target_mem_type_set,
-    mark_output_tensors_overlapping_with_input_tensors=False,
     use_ifm_ofm_overlap=True,
     ignore_subgraph_input_output_tensors=False,
     lr_graph=None,
@@ -222,41 +252,17 @@
         lr_graph.ignore_tensors.update(sg.input_tensors)
         lr_graph.ignore_tensors.update(sg.output_tensors)
 
-    def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
-        if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
-            return True
-        if tens in lr_graph.ignore_tensors:
-            return True
-        if tens.name.endswith("reshape_shape_npu"):
-            # Reshape tensor, no need to allocate
-            lr_graph.ignore_tensors.add(tens)
-            return True
-        return False
-
-    def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area, target_mem_type_set):
-        for ps in sg.passes:
-            if ps.placement == PassPlacement.MemoryOnly:
-                # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
-                input_tensor = ps.inputs[0]
-                output_tensor = ps.outputs[0]
-                if not tensor_should_be_ignored(input_tensor, target_mem_area, target_mem_type_set) and not (
-                    tensor_should_be_ignored(output_tensor, target_mem_area, target_mem_type_set)
-                ):
-                    lr_graph.fuse_ranges(input_tensor, output_tensor)
-
-    # Merge only memory operations in the NPU subgraphs
+    # Try to merge live ranges of operations in the NPU subgraphs
     if sg.placement == PassPlacement.Npu:
-        merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area, target_mem_type_set)
+        merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)
 
     for cps in sg.cascaded_passes:
         cps.time = lr_graph.current_time
 
         time_for_pass = cps.time
 
-        is_element_wise = cps.is_element_wise
-
         for tens in cps.inputs:
-            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
                 continue
             rng = lr_graph.get_or_create_range(tens, allocation_alignment)
             rng.mark_usage(time_for_pass)
@@ -273,33 +279,18 @@
             # Use default allocation alignment of 16 for Npu tensors
             npu_sg = cps_primary_op.attrs["subgraph"]
             lr_graph = extract_live_ranges_from_cascaded_passes(
-                npu_sg,
-                target_mem_area,
-                target_mem_type_set,
-                mark_output_tensors_overlapping_with_input_tensors,
-                use_ifm_ofm_overlap,
-                False,
-                lr_graph,
+                npu_sg, target_mem_area, target_mem_type_set, use_ifm_ofm_overlap, False, lr_graph,
             )
             # Set the new time after handling the Npu subgraph
             time_for_pass = lr_graph.current_time
             cps.time = time_for_pass
 
-        for tens in cps.intermediates:
-            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+        for tens in cps.intermediates + cps.outputs:
+            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
                 continue
             rng = lr_graph.get_or_create_range(tens, allocation_alignment)
             rng.mark_usage(time_for_pass)
 
-        for tens in cps.outputs:
-            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
-                continue
-            rng = lr_graph.get_or_create_range(tens, allocation_alignment)
-            output_time = time_for_pass
-            if not mark_output_tensors_overlapping_with_input_tensors and is_element_wise:
-                output_time += 1
-            rng.mark_usage(output_time)
-
         if use_ifm_ofm_overlap:
             # fill allowed overlap for ifm and ofm tensor
             ifm_tensor = cps.passes[0].ifm_tensor
@@ -307,8 +298,8 @@
             if (
                 ifm_tensor is not None
                 and ofm_tensor is not None
-                and not tensor_should_be_ignored(ifm_tensor, target_mem_area, target_mem_type_set)
-                and not tensor_should_be_ignored(ofm_tensor, target_mem_area, target_mem_type_set)
+                and not tensor_should_be_ignored(lr_graph, ifm_tensor, target_mem_area, target_mem_type_set)
+                and not tensor_should_be_ignored(lr_graph, ofm_tensor, target_mem_area, target_mem_type_set)
             ):
                 lr_graph.allowed_overlaps[(ifm_tensor, ofm_tensor)] = calc_allowed_ofm_ifm_overlap_for_cascaded_pass(
                     cps
@@ -322,7 +313,7 @@
         end_time = max(end_time, rng.end_time)
 
     for tens in sg.output_tensors:
-        if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+        if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
             continue
         rng = lr_graph.get_or_create_range(tens, allocation_alignment)
         rng.mark_usage(end_time)
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 41e1529..31e6383 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -480,10 +480,7 @@
     def calc_non_local_mem_usage(self):
         ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
         range_set = live_range.extract_live_ranges_from_passes(
-            self.sg,
-            self.mem_area,
-            mark_output_tensors_overlapping_with_input_tensors=True,
-            ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
+            self.sg, self.mem_area, ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
         )
         range_dict = range_set.ranges
 
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index 1efcd68..9f14ec4 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -137,7 +137,6 @@
         sg,
         mem_area,
         mem_type_set,
-        mark_output_tensors_overlapping_with_input_tensors=False,
         use_ifm_ofm_overlap=use_ifm_ofm_overlap,
         ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
         lr_graph=lr_graph,