MLBEDSW-3951 Consider reshaping in pass packing

Consider reshaping in pass packing when deciding if
operators can be packed.
When there is a reshape between two ops,
they cannot be fused.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I8f2833b3fff156e9633ce0189d1d0df9109a6622
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index a95e383..c973b9c 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -150,7 +150,7 @@
         # ops_set
         npu_pre_ops,
         # incompatible_pack_flags
-        PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.ElementWise,
+        PassFlags.Cpu | PassFlags.MemoryOnly,
         # flags_to_set
         PassFlags.Npu | PassFlags.Mac | PassFlags.Pre | PassFlags.ElementWise,
         # flags_to_clear
@@ -296,21 +296,9 @@
                         for inp in reversed(curr_op.inputs):
                             if inp is None:
                                 continue
-                            can_pack = True
-                            if len(inp.ops) == 1:
-                                next_op = inp.ops[0]
-                                for outp in next_op.outputs:
-                                    consumers = outp.consumers()
-                                    if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op):
-                                        can_pack = False
-                                        break
+                            if can_pack(inp, curr_op):
+                                to_process.append((inp.ops[0], inp))
                             else:
-                                can_pack = False
-
-                            if can_pack:
-                                to_process.append((next_op, inp))
-                            else:
-                                assert inp is not None
                                 input_set.add(inp)
 
                         break
@@ -469,6 +457,27 @@
 
         return None
 
+    def can_pack(inp, curr_op):
+        if len(inp.ops) == 1:
+            next_op = inp.ops[0]
+            for outp in next_op.outputs:
+                consumers = outp.consumers()
+                if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op):
+                    return False
+
+            # There cannot be any reshaping between next_op ofm and corresponding curr_op ifm
+            if len(curr_op.ifm_shapes) != 0 and len(next_op.ofm_shapes) != 0:
+                if inp == curr_op.ifm and next_op.ofm_shapes[0] != curr_op.ifm_shapes[0]:
+                    return False
+                elif (
+                    curr_op.ifm2 is not None and inp == curr_op.ifm2 and next_op.ofm_shapes[0] != curr_op.ifm_shapes[1]
+                ):
+                    return False
+        else:
+            return False
+
+        return True
+
     def add_input_list(inp_to_add, inp_set, inp_refcnts, lut_list, ordered_inp_list):
         if inp_to_add in inp_set:
             if inp_refcnts[inp_to_add] == 0: