MLBEDSW-4839: Fix issues with Elementwise IFM/OFM overlap

 - Fixed typo where the ifm2 eligibility check tested ifm.mem_type instead of ifm2.mem_type
 - Fixed bug with using ifm1 properties when only ifm2 is a potential match
 - Removed the restriction that excluded SHL and SHR ops from overlap consideration
 - Removed some dead reshape code

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: Id9bcc3c2b3ee9ac7b6276187d3e2f513b4acd4b5
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 2795b66..7ff1b28 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -16,6 +16,7 @@
 # Description:
 # Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
 # Can work with either a pass packed subgraph or a scheduled subgraph.
+from collections import namedtuple
 from typing import List
 
 import numpy as np
@@ -159,47 +160,45 @@
         return True
     if tens in lr_graph.ignore_tensors:
         return True
-    if tens.name.endswith("reshape_shape_npu"):
-        # Reshape tensor, no need to allocate
-        lr_graph.ignore_tensors.add(tens)
-        return True
     return False
 
 
 def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set):
+    def _tensor_should_be_ignored(tens):
+        return tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set)
+
     # Tries to merge ifm/ofm live ranges of elementwise op
     if sched_op.op_type.is_elementwise_op():
         elem_op = sched_op.parent_op
-        if not tensor_should_be_ignored(lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set):
+        if not _tensor_should_be_ignored(elem_op.ofm):
             # Check if overwriting the inputs can be allowed
-            if elem_op.type not in (Op.SHL, Op.SHR):
-                inps = []
-                if (
-                    elem_op.ifm is not None
-                    and elem_op.ifm.shape != []
-                    and elem_op.ifm.mem_area == target_mem_area
-                    and elem_op.ifm.mem_type in target_mem_type_set
-                ):
-                    inps.append(elem_op.ifm)
-                if (
-                    elem_op.ifm2 is not None
-                    and elem_op.ifm2.shape != []
-                    and elem_op.ifm2.mem_area == target_mem_area
-                    and elem_op.ifm.mem_type in target_mem_type_set
-                ):
-                    inps.append(elem_op.ifm2)
+            OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
+            outp = OpShapeTens(elem_op.ofm_shapes[0], elem_op.ofm)
+            inps = []
+            if elem_op.ifm is not None:
+                inps.append(OpShapeTens(elem_op.ifm_shapes[0], elem_op.ifm))
+            if elem_op.ifm2 is not None:
+                inps.append(OpShapeTens(elem_op.ifm_shapes[1], elem_op.ifm2))
 
-                if len(inps) > 0:
-                    for i, inp in enumerate(inps):
-                        # check input format, dtype, broadcasting or if there are more input consumers
-                        if (
-                            inp.format == elem_op.ofm.format
-                            and inp.dtype == elem_op.ofm.dtype
-                            and elem_op.ifm_shapes[i] == elem_op.ofm_shapes[0]
-                            and (len(inp.consumer_list) == 1 and len(inp.ops) == 1)
-                        ):
-                            lr_graph.fuse_ranges(inp, elem_op.ofm)
-                            break
+            # find an input tensor that can be overwritten by the output
+            for inp in inps:
+                if (
+                    # check op input and output shapes allow overlapping
+                    inp.op_shape == outp.op_shape
+                    # check input tensor is valid
+                    and inp.tens is not None
+                    and inp.tens.shape != []
+                    and not _tensor_should_be_ignored(inp.tens)
+                    # check input and output tensors are compatible
+                    and inp.tens.format == outp.tens.format
+                    and inp.tens.dtype == outp.tens.dtype
+                    # check input tensor only has one consumer
+                    and len(inp.tens.consumer_list) == 1
+                    # check output tensor only has one producer
+                    and len(outp.tens.ops) == 1
+                ):
+                    lr_graph.fuse_ranges(inp.tens, outp.tens)
+                    break
 
 
 def extract_live_ranges_from_cascaded_passes(
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 518b243..b28f4eb 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -348,8 +348,7 @@
                         ps.ifm_shapes.append(op.ifm_shapes[0])
                     elif ps.ifm_tensor == op.ifm2:
                         ps.ifm_shapes.append(op.ifm_shapes[1])
-            for op in input_ops_list + [primary_op]:
-                if op.run_on_npu:
+
                     if ps.ifm2_tensor == op.ifm:
                         ps.ifm_shapes.append(op.ifm_shapes[0])
                     elif ps.ifm2_tensor == op.ifm2:
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 3701893..fd3bf42 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -39,7 +39,7 @@
 from .tflite_mapping import BuiltinOperator
 from .tflite_mapping import datatype_inv_map
 
-# ugh, the python flatbuffer interface is missing a method to add in file identifier. patching it in here:
+# the python flatbuffer interface is missing a method to add in file identifier. patching it in here:
 
 tflite_version = 3
 tflite_file_identifier = "TFL" + str(tflite_version)