MLBEDSW-6975: Updated bypass functionality

- The previous patch, which always replaced ifm with ofm,
introduced unnecessary avg pool ops in some cases.
That patch has been reverted and this is a new solution.

- Replace ifm with ofm under the following conditions:

a) Ops that depend on the original ifm tensor shape
not being changed by the bypass memory op function.
b) When the memory op's IFM and OFM have different ranks.
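
For example (condition a): a MEAN op that consumes the 1x20x3x6
output of a RESHAPE relies on that exact shape, so the RESHAPE is
bypassed by replacing its ifm with its ofm (see the diagram in the
diff below).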

Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I16a023e169ae64c5db46f6f88516a5e1ca7ed7ef
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index b33851a..e2ee06b 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -39,6 +39,10 @@
     Op.Identity,
 )
 
+# Ops that depend on their original ifm tensor shape not being changed
+# by the bypass memory op function
+original_ifm_shape_ops = (Op.Mean,)
+
 
 def _avoid_nhcwb16_for_concat(tens):
     # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -195,6 +199,14 @@
     return op
 
 
+def bypass_need_to_keep_ofm_shape(op):
+    # Check if ifm must be replaced by ofm (the rank is changed or an op that follows needs the original ifm shape)
+    ifm_replaced_by_ofm = any(
+        ofm_cons is not None and ofm_cons.type in original_ifm_shape_ops for ofm_cons in op.ofm.consumer_list
+    ) or len(op.ifm.shape) != len(op.ofm.shape)
+    return ifm_replaced_by_ofm
+
+
 def bypass_memory_only_ops(op):
     assert op.type in memory_only_ops
     ofm = op.ofm
@@ -211,7 +223,7 @@
     # This case should be handled prior to this function
     assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
 
-    if ofm_is_sg_ofm or ofm_is_cpu_consumed:
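+    # Replace ifm with ofm only when the shapes differ and the ofm must be
+    # kept: it is a sg ofm, it is cpu consumed, or the ofm shape is needed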
+    if (ifm.shape != ofm.shape) and (ofm_is_sg_ofm or ofm_is_cpu_consumed or bypass_need_to_keep_ofm_shape(op)):
         # Bypassed by replacing ifm with ofm
         ofm.ops = []
         for prev_op in ifm.ops:
@@ -261,6 +273,20 @@
         DebugDatabase.add_optimised(op, op)
 
 
+def insert_copy_op_before_op(op):
+    # Create an avg_pool NOP with ifm as input
+    tens = op.ifm
+    copy_tens = tens.clone()
+    copy_op = create_avgpool_nop(f"{tens.name}_avgpool")
+    copy_op.add_input_tensor(tens)
+    copy_op.set_output_tensor(copy_tens)
+    copy_op.set_ifm_ofm_shapes()
+
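+    # The copy op's output tensor replaces ifm as the op's input 0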
+    op.set_input_tensor(copy_tens, 0)
+
+    DebugDatabase.add_optimised(op, copy_op)
+
+
 def insert_copy_op_after_tens(tens):
     tens_cons_list_copy = tens.consumer_list.copy()
 
@@ -293,6 +319,31 @@
     # This is also valid when reshape ifm/ofm is produced respectively
     # consumed by CPU
 
+    # Rare case: original_ifm_shape_ops contains ops that depend on their
+    # original ifm tensor shape not being changed by the bypass memory op
+    # function. If the memory only op's ifm is a subgraph ifm, is cpu produced,
+    # or has multiple consumers, then an avgpool NOP needs to be inserted
+    # before the original_ifm_shape_ops op (and only before that op). The same
+    # applies when the memory only op changes the rank between its IFM and OFM.
+    #
+    # Below is an example showing when an avgpool NOP is needed because a
+    # RESHAPE is bypassed by replacing IFM with OFM.
+    #
+    #                Converts to     And in bypass_memory_only_ops
+    #                    --->                     --->
+    #    -----ADD-----           -----ADD-----           -----ADD-----
+    #    |           |           |           |           |           |
+    # 1x6x6x10    1x6x6x10   1x6x6x10    1x6x6x10    1x6x6x10    1x6x6x10
+    #  RESHAPE      MEAN     AVG POOL      MEAN      AVG POOL      MEAN
+    #    |                       |           |           |
+    # 1x20x3x6               1x6x6x10                1x20x3x6
+    #   MEAN                  RESHAPE                  MEAN
+    #                            |
+    #                        1x20x3x6
+    #                          MEAN
+    ifm_has_multiple_cons = len(op.ifm.consumer_list) > 1
+
     # Check if operator ifm/ofm are sg ifm/ofm
     ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
     ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
@@ -301,6 +352,20 @@
     ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
     ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
 
+    if bypass_need_to_keep_ofm_shape(op):
+        # The bypass needs to keep the OFM shape
+        if ifm_has_multiple_cons:
+            # Rare case:
+            # IFM needs to persist due to multiple consumers, so a copy op is needed
+            # OFM will replace IFM for the memory only op
+            insert_copy_op_before_op(op)
+        elif not (ofm_is_sg_ofm or ofm_is_cpu_consumed):
+            # Only one consumer, and OFM is not a subgraph output or cpu
+            # consumed, so it is safe to replace ifm.shape with ofm.shape.
+            # IFM can then replace OFM for the memory only op and no copy op is needed
+            op.ifm.shape = op.ofm.shape
+
+    # Special case when both IFM and OFM need to persist
     if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
         # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the memory only operator.
         insert_copy_op_after_tens(op.ifm)