MLBEDSW-7312: Refactoring bypass_memory_only_ops

- The logic for bypassing memory only ops is
complicated and still does not handle all corner cases.
- This patch simplifies the logic by always bypassing
the op, replacing the IFM with the OFM. If that is not
possible, the memory only op is changed to a memcpy op.
- The bypassing was previously done in two steps but
is now reduced to one.
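- A minimal sketch of the new single-step logic,
condensed from the patch below (the function name is a
placeholder, not the patch's exact code; Op and the
op/tensor types are Vela's, see graph_optimiser_util.py):

    def handle_memory_only_op(op):
        ifm, ofm = op.ifm, op.ofm
        ifm_has_multiple_cons = len(ifm.consumer_list) > 1
        ifm_is_cpu_produced = any(
            prod is not None and not prod.run_on_npu for prod in ifm.ops
        )
        if ifm_has_multiple_cons or ifm_is_cpu_produced:
            # The IFM must be preserved, so the op cannot be
            # bypassed; rewrite it to a memcpy (DMA) op instead
            op.type = Op.Memcpy
        else:
            # Bypass the op by replacing the IFM with the OFM:
            # reconnect the IFM's producers to write the OFM
            ofm.ops = []
            for prev_op in ifm.ops:
                prev_op.outputs = [ofm]
                ofm.ops.append(prev_op)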

Change-Id: I545dd65e0ec77c70be479a5ada2d277cac3a027c
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index e8d5ac6..8b24eaf 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -27,7 +27,6 @@
 from .errors import UnsupportedFeatureError
 from .errors import VelaError
 from .operation import Op
-from .operation_util import create_memcpy
 from .shape4d import Shape4D
 from .tensor import create_const_tensor
 from .tensor import QuantizationParameters
@@ -40,10 +39,6 @@
     Op.Identity,
 )
 
-# Ops that are dependent that the original ifm tensor shape is not changed
-# by the bypass memory op function
-original_ifm_shape_ops = (Op.Mean,)
-
 
 def _avoid_nhcwb16_for_concat(tens):
     # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -209,50 +204,6 @@
     return op
 
 
-def bypass_need_to_keep_ofm_shape(op):
-    # Check if ifm must be replaced by ofm (rank is changed or the op that follow must have original ifm shape)
-    ifm_replaced_by_ofm = any(
-        ofm_cons is not None and ofm_cons.type in original_ifm_shape_ops for ofm_cons in op.ofm.consumer_list
-    ) or len(op.ifm.shape) != len(op.ofm.shape)
-    return ifm_replaced_by_ofm
-
-
-def bypass_memory_only_ops(op):
-    assert op.type in memory_only_ops
-    ofm = op.ofm
-    ifm = op.ifm
-
-    # Check if ifm/ofm are network ifm/ofm
-    ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
-    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
-    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
-    # Check if ifm/ofm is produced respectively consumed by CPU
-    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
-    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
-
-    # This case should be handled prior to this function
-    assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
-
-    if (ifm.shape != ofm.shape) and (ofm_is_sg_ofm or ofm_is_cpu_consumed or bypass_need_to_keep_ofm_shape(op)):
-        # Bypassed by replacing ifm with ofm
-        ofm.ops = []
-        for prev_op in ifm.ops:
-            prev_op.outputs = [ofm]
-            ofm.ops.append(prev_op)
-
-        # All ifm consumers need to use ofm as input
-        for ifm_cons in ifm.consumer_list:
-            for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
-                if cons_ifm == ifm:
-                    ifm_cons.set_input_tensor(ofm, ifm_idx)
-    else:
-        # Bypassed by replacing ofm with ifm
-        for cons in ofm.consumer_list:
-            for ifm_idx, cons_ifm in enumerate(cons.inputs):
-                if cons_ifm == ofm:
-                    cons.set_input_tensor(ifm, ifm_idx)
-
-
 def move_splitsliceread_to_consumer(op, cons_op):
     assert op.type == Op.SplitSliceRead
 
@@ -282,117 +233,62 @@
         DebugDatabase.add_optimised(op, op)
 
 
-def insert_copy_op_before_op(op):
-    # Create a memcpy op with ifm as input
-    tens = op.ifm
-    copy_tens = tens.clone()
-    copy_op = create_memcpy(f"{tens.name}_memcpy")
-    copy_op.add_input_tensor(tens)
-    copy_op.set_output_tensor(copy_tens)
-    copy_op.set_ifm_ofm_shapes()
-
-    op.set_input_tensor(copy_tens, 0)
-
-    DebugDatabase.add_optimised(op, copy_op)
-
-
-def insert_copy_op_after_tens(tens):
-    tens_cons_list_copy = tens.consumer_list.copy()
-
-    # Create a mempcy op with ifm as input
-    copy_tens = tens.clone()
-    copy_op = create_memcpy(tens.name + "_memcpy")
-    copy_op.add_input_tensor(tens)
-    copy_op.set_output_tensor(copy_tens)
-    copy_op.set_ifm_ofm_shapes()
-    copy_op.run_on_npu = True
-
-    # Set copy_ifm consumers
-    for tens_cons in tens_cons_list_copy:
-        if tens_cons is not None:
-            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
-                if cons_inp == tens:
-                    tens_cons.set_input_tensor(copy_tens, ifm_idx)
-
-    DebugDatabase.add_optimised(tens.ops[0], copy_op)
-
-
-def fix_sg_input_output(op, arch, nng):
+def bypass_memory_only_ops(op, arch, nng):
     if not op.run_on_npu or op.type not in memory_only_ops:
         return op
 
-    prev_op = op.ifm.ops[0]
-    while prev_op is not None and prev_op.run_on_npu and prev_op.type in memory_only_ops:
-        # Current op is preceded by another memory only op.
-        # Replace current op's ifm with the preceding op's ifm. By doing
-        # this the preceding op is removed from current path.
-        next_prev_op = prev_op.ifm.ops[0]
-        if next_prev_op is not None and next_prev_op.run_on_npu and next_prev_op.type in memory_only_ops:
-            # Preceding op also have a preceding memory only op
-            prev_op = next_prev_op
-        else:
-            op.set_input_tensor(prev_op.ifm, 0)
-            break
-
-    # For the memory only operators we want to remove, tensors are removed.
-    # But in order to to do this, they cannot be outputs of the sg,
-    # this need to be fixed prior to the removal.
-    # Solution is to add a avgpool NOP, to maintain the original tensor.
-    # This is also valid when reshape ifm/ofm is produced respectively
-    # consumed by CPU
-
-    # Rare case: original_ifm_shape_ops contain ops that are dependent
-    # that the original ifm tensor shape is not changed by the bypass memory
-    # function. If the memory only op ifm is subgraph ifm/ifm is cpu produced
-    # or the ifm is consumed by many, then there is a need to insert an avgpool
-    # NOP before the original_ifm_shape_ops. Also note that the NOP is only inserted
-    # before original_ifm_shape_ops. The above is also true when the memory only
-    # op change the rank between the IFM and OFM.
+    # Memory only operators can be completely removed if there is a one-to-one
+    # connection. The reshape OFM can be connected to the previous op.
     #
-    # Below is an example showing the case when there is a need for an AVG NOP
-    # when RESHAPE is bypassed by replacing IFM with OFM.
+    #                Bypassed to
+    #                    --->
+    #       1x6x6x10             1x6x6x10
+    #         ADD                  ADD
+    #          |          ------->  |
+    #       1x6x6x10      |      1x20x3x6
+    #        RESHAPE      |        MEAN
+    #          | ---------|
+    #       1x20x3x6
+    #         MEAN
     #
-    #                Converts to           And in bypass_memory
-    #                    --->                     --->
-    #    -----ADD-----           -----ADD-----           -----ADD-----
-    #    |           |           |           |           |           |
-    # 1x6x6x10    1x6x6x10   1x6x6x10    1x6x6x10    1x6x6x10    1x6x6x10
-    #  RESHAPE      MEAN     AVG POOL      MEAN      AVG POOL      MEAN
-    #    |                       |           |           |
-    # 1x20x3x6               1x6x6x10                1x20x3x6
-    #   MEAN                  RESHAPE                  MEAN
-    #                            |
-    #                        1x20x3x6
-    #                          MEAN
+    # In the example above, the ADD OFM (= the RESHAPE IFM) is removed and
+    # replaced by the RESHAPE OFM.
+    #
+    # There are two cases where bypassing is not possible. One is when the
+    # IFM is produced by the CPU; that tensor must be preserved and cannot
+    # be removed from the graph. The other is when the IFM has multiple
+    # consumers; then the op cannot simply be bypassed and a DMA (nop) is
+    # needed instead.
+    #
+    #                Converts to
+    #                    --->
+    #       1x6x6x10                1x6x6x10
+    #    -----ADD-----           -----ADD-----
+    #    |           |           |           |
+    # 1x6x6x10    1x6x6x10   1x6x6x10    1x6x6x10
+    #  RESHAPE      MEAN       DMA OP      MEAN
+    #    |                       |
+    # 1x20x3x6               1x20x3x6
+    #   MEAN                   MEAN
+    #
+    # If the DMA IFM and the DMA OFM end up in the same memory area,
+    # the DMA op will be removed when the command stream is generated.
+
     ifm_has_multiple_cons = len(op.ifm.consumer_list) > 1
-
-    # Check if operator ifm/ofm are sg ifm/ofm
-    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
-    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
-    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
-    # Check if ifm/ofm is produced respectively consumed by CPU
     ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
-    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
 
-    if bypass_need_to_keep_ofm_shape(op):
-        # Bypass need to keep OFM shape
-        if ifm_has_multiple_cons:
-            # Rare case:
-            # IFM need to persist due to multiple consumers and copy op is needed
-            # OFM will replace IFM for the memory only op
-            insert_copy_op_before_op(op)
-            # One copy added so no need to check for another copy further down
-            return op
-        elif not (ofm_is_sg_ofm or ofm_is_cpu_consumed):
-            # Only one consumer and OFM is not subgraph output or cpu consumed,
-            # safe to replace ifm.shape by ofm.shape
-            # IFM can then replace OFM for the memory only op and no copy op is needed
-            op.ifm.shape = op.ofm.shape
-
-    # Special case when when OFM is sg_ofm or cpu_consumed
-    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
-        # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the memory only operator.
-        insert_copy_op_after_tens(op.ifm)
+    if ifm_has_multiple_cons or ifm_is_cpu_produced:
+        # Convert to a memcpy op
+        op.type = Op.Memcpy
+        DebugDatabase.add_optimised(op, op)
+    else:
+        # Bypass op
+        ofm = op.ofm
+        ifm = op.ifm
+        ofm.ops = []
+        for prev_op in ifm.ops:
+            prev_op.outputs = [ofm]
+            ofm.ops.append(prev_op)
 
     return op
 
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 3a49309..a1cbb3e 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -35,7 +35,6 @@
 from .graph_optimiser_util import calc_explicit_padding
 from .graph_optimiser_util import convert_depthwise_to_conv
 from .graph_optimiser_util import convert_to_lut
-from .graph_optimiser_util import fix_sg_input_output
 from .graph_optimiser_util import memory_only_ops
 from .graph_optimiser_util import move_splitsliceread_to_consumer
 from .graph_optimiser_util import needed_total_padding
@@ -1362,11 +1361,6 @@
     return op
 
 
-def remove_memory_only_ops(op, arch):
-    if op.run_on_npu and op.type in memory_only_ops:
-        bypass_memory_only_ops(op)
-
-
 def fuse_activation_function_with_prev(op, arch, nng):
     # if op is a no-op: attempts to move the activation function to the preceding op
     if not op.attrs.get("is_nop", False) or op.activation is None:
@@ -1954,22 +1948,17 @@
             rewrite_unsupported=False,
         )
 
-    # Handle sg input output
+    # Bypass or rewrite memory only operators
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
             nng,
             sg,
             arch,
             [],
-            [fix_sg_input_output],
+            [bypass_memory_only_ops],
             rewrite_unsupported=False,
         )
 
-    # Removal of memory only operators
-    for sg in nng.subgraphs:
-        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
-        sg.refresh_after_modification()
-
     # Rewrite of operators
     op_rewrite_list = [
         set_tensor_equivalence,