MLBEDSW-6931: Refactoring merge elementwise ops

Change the code in the cascade builder to use the
common functionality in the live_range module instead.

Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I7bbd7ea3d1e7e085813e9d93256a54e6bab2267b
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 9b6fe63..fbb48ec 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -154,18 +154,21 @@
 
 
 def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+    if target_mem_area is None or target_mem_type_set is None:
+        return False
     if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
         return True
     return False
 
 
-def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
+def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
     def _tensor_should_be_ignored(tens):
         if tens.ifm_write_protected:
             return True
         return tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set)
 
-    # Tries to merge ifm/ofm live ranges of elementwise op
+    # Check if possible to merge ifm/ofm live ranges of elementwise op
+    ifm_tens = None
     if sched_op.op_type.is_elementwise_op():
         elem_op = sched_op.parent_op
         if not _tensor_should_be_ignored(elem_op.ofm):
@@ -195,9 +198,22 @@
                     # check output tensor only has one producer
                     and len(outp.tens.ops) == 1
                 ):
-                    lr_graph.fuse_ranges(inp.tens, outp.tens)
+                    ifm_tens = inp.tens
                     break
 
+    return ifm_tens
+
+
+def ofm_can_reuse_ifm(sched_op, target_mem_area=None, target_mem_type_set=None):
+    ifm = _get_ifm_to_fuse(sched_op, target_mem_area, target_mem_type_set)
+    return ifm is not None
+
+
+def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
+    ifm = _get_ifm_to_fuse(sched_op, target_mem_area, target_mem_type_set)
+    if ifm:
+        lr_graph.fuse_ranges(ifm, sched_op.parent_op.ofm)
+
 
 def extract_live_ranges_from_cascaded_passes(
     sg,