Revert "MLBEDSW-6961: Bypass functionality for memory ops"

This reverts commit 5060ff53f5ac2382e04a68d7772bd71a36f63845.

Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I8dd7e9ed8325fd2e8c17509fd9757292706f5ee7
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index e6a79ce..b33851a 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -200,25 +200,35 @@
     ofm = op.ofm
     ifm = op.ifm
 
-    # Check if ifm is subgraph ifm
+    # Check if ifm/ofm are network ifm/ofm
     ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
-    # Check if ifm is produced by CPU
+    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
+    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
+    # Check if ifm is produced by CPU and if ofm is consumed by CPU
     ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
 
     # This case should be handled prior to this function
-    assert not (ifm_is_sg_ifm or ifm_is_cpu_produced)
+    assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
 
-    # Bypassed by replacing ifm with ofm
-    ofm.ops = []
-    for prev_op in ifm.ops:
-        prev_op.outputs = [ofm]
-        ofm.ops.append(prev_op)
+    if ofm_is_sg_ofm or ofm_is_cpu_consumed:
+        # Bypassed by replacing ifm with ofm
+        ofm.ops = []
+        for prev_op in ifm.ops:
+            prev_op.outputs = [ofm]
+            ofm.ops.append(prev_op)
 
-    # All ifm consumers need to use ofm as input
-    for ifm_cons in ifm.consumer_list:
-        for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
-            if cons_ifm == ifm:
-                ifm_cons.set_input_tensor(ofm, ifm_idx)
+        # All ifm consumers need to use ofm as input
+        for ifm_cons in ifm.consumer_list:
+            for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
+                if cons_ifm == ifm:
+                    ifm_cons.set_input_tensor(ofm, ifm_idx)
+    else:
+        # Bypassed by replacing ofm with ifm
+        for cons in ofm.consumer_list:
+            for ifm_idx, cons_ifm in enumerate(cons.inputs):
+                if cons_ifm == ofm:
+                    cons.set_input_tensor(ifm, ifm_idx)
 
 
 def move_splitsliceread_to_consumer(op, cons_op):
@@ -251,8 +261,8 @@
         DebugDatabase.add_optimised(op, op)
 
 
-def insert_copy_op_after_ifm(op):
-    tens = op.ifm
+def insert_copy_op_after_tens(tens):
+    tens_cons_list_copy = tens.consumer_list.copy()
 
     # Create a avg_pool nop op with ifm as input
     copy_tens = tens.clone()
@@ -262,7 +272,12 @@
     copy_op.set_ifm_ofm_shapes()
     copy_op.run_on_npu = True
 
-    op.set_input_tensor(copy_tens, 0)
+    # Redirect all original consumers of tens to use copy_tens instead
+    for tens_cons in tens_cons_list_copy:
+        if tens_cons is not None:
+            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
+                if cons_inp == tens:
+                    tens_cons.set_input_tensor(copy_tens, ifm_idx)
 
     DebugDatabase.add_optimised(tens.ops[0], copy_op)
 
@@ -271,26 +286,24 @@
     if not op.run_on_npu or op.type not in memory_only_ops:
         return op
 
-    # For the memory only operators we want to remove, the ifm tensor
-    # is replaced by the ofm tensor.
-    # But in order to to do this, the ifm can not be inputs of the sg or
-    # the ifm can not have more than one consumers.
-    # This need to be fixed prior to the removal.
+    # For the memory only operators we want to remove, tensors are removed.
+    # But in order to do this, they cannot be outputs of the sg;
+    # this needs to be fixed prior to the removal.
     # Solution is to add a avgpool NOP, to maintain the original tensor.
-    # This is also valid when reshape ifm is produced by CPU
+    # This is also valid when the reshape ifm is produced by CPU
+    # or the reshape ofm is consumed by CPU
 
-    # Check if operator ifm is subgraph ifm
+    # Check if operator ifm/ofm are sg ifm/ofm
     ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
-
-    # Check if ifm is produced by CPU
+    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
+    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
+    # Check if ifm is produced by CPU and if ofm is consumed by CPU
     ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
 
-    # Check numbers of ifm consumers - if many insert avgpool NOP
-    ifm_has_multiple_cons = len(op.ifm.consumer_list) > 1
-
-    if ifm_is_sg_ifm or ifm_is_cpu_produced or ifm_has_multiple_cons:
-        # Ifm need to persist in order to remove the memory only operator.
-        insert_copy_op_after_ifm(op)
+    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
+        # Both ifm and ofm need to persist, but only ifm needs a copy, in order to remove the memory only operator.
+        insert_copy_op_after_tens(op.ifm)
 
     return op