MLBEDSW-7019: Update to elementwise cascading

- The cascade builder uses the ifm_ifm2_correct_order
function to decide whether the operator is cascadable or not.
The problem is that this function expects a full shape or no shape,
and the cascade builder did not provide that, so the operator was
reported as non-cascadable.

- The fix is to provide a full 4D shape, and to refactor
ifm_ifm2_correct_order to take 4D shapes to avoid confusion
in the future (see the sketch below).

- Refactored the code so that the scheduler can perform a
correct ifm and ifm2 swap.
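
- For reference, a minimal standalone sketch of the ordering rule
after the change; the Shape4D class below is a simplified stand-in
that only provides the as_list() accessor used in the diff, not
Vela's real implementation:

    from typing import List, Optional

    class Shape4D:
        """Simplified stand-in; only as_list() is assumed here."""

        def __init__(self, batch: int, height: int, width: int, depth: int):
            self.shape = [batch, height, width, depth]

        def as_list(self) -> List[int]:
            return self.shape

    def ifm_ifm2_correct_order(ifm_shape: Optional[Shape4D], ifm2_shape: Optional[Shape4D]) -> bool:
        # A scalar (no shape) must be placed in IFM2
        if ifm_shape is None:
            return False
        if ifm2_shape is None:
            return True
        # A broadcast feature map (dimension of 1 where the other
        # operand differs) must also be placed in IFM2
        for ifm, ifm2 in zip(ifm_shape.as_list(), ifm2_shape.as_list()):
            if ifm != ifm2 and ifm == 1:
                return False
        return True

    # With full 4D shapes the check now gives a definite answer;
    # a broadcast in IFM means the operands must be swapped.
    assert ifm_ifm2_correct_order(Shape4D(1, 8, 8, 16), Shape4D(1, 8, 8, 16))
    assert not ifm_ifm2_correct_order(Shape4D(1, 1, 1, 16), Shape4D(1, 8, 8, 16))
    assert ifm_ifm2_correct_order(Shape4D(1, 8, 8, 16), None)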

Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I9a86c4690612f332afa428456a07e67698852495
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 9c84ba8..b042ba7 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -98,7 +98,7 @@
             and cost.stripe.height < sched_op.ofm.shape.height
             and sched_op.parent_op.read_offsets[0] is None
             and sched_op.parent_op.read_offsets[1] is None
-            and self.element_wise_cascading_conformity(sched_op)
+            and self.elementwise_cascading_correct_order(sched_op)
             and not sched_op.parent_op.type.is_resize_op()
             and not sched_op.parent_op.type == Op.Conv2DBackpropInputSwitchedBias
             and sched_op.parent_op.attrs.get("padding", None) != Padding.TILE
@@ -127,22 +127,34 @@
         return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)
 
     @staticmethod
-    def element_wise_cascading_conformity(sched_op):
+    def elementwise_cascading_conformity(sched_op):
         """Check the inputs of the op to see if it's a candidate for cascading."""
 
-        ifm = sched_op.parent_op.ifm
-        ifm2 = sched_op.parent_op.ifm2
-
-        if sched_op.parent_op.type.is_binary_elementwise_op() and ifm and ifm2:
+        if sched_op.parent_op.type.is_binary_elementwise_op():
             # We cannot rule out cascadability if at least one IFM is constant
+            ifm = sched_op.parent_op.ifm
+            ifm2 = sched_op.parent_op.ifm2
             ifm_const = ifm.ops != [] and ifm.ops[0].type == Op.Const
             ifm2_const = ifm2.ops != [] and ifm2.ops[0].type == Op.Const
-            correct_order = ifm_ifm2_correct_order(ifm.shape, ifm2.shape)
-            return (ifm_const and (ifm.shape == ifm2.shape or not correct_order)) or (ifm2_const and correct_order)
+            return ifm_const or ifm2_const
         else:
             # Either one IFM is not variable or it is not a binary elementwise op - we cannot rule out cascadability
             return True
 
+    @staticmethod
+    def elementwise_cascading_correct_order(sched_op):
+        """Check the inputs of the op to see ifm and ifm2 has correct order."""
+
+        if sched_op.parent_op.type.is_binary_elementwise_op():
+            ifm2 = sched_op.parent_op.ifm2
+            ifm2_const = ifm2.ops != [] and ifm2.ops[0].type == Op.Const
+
+            # ifm_ifm2_correct_order needs full shape
+            correct_order = ifm_ifm2_correct_order(sched_op.ifm.shape, sched_op.ifm2.shape)
+            return ifm2_const and correct_order
+        else:
+            return True
+
     def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
         ref_cost = ref_schedule.cost_map
         fallback_cost = fallback_schedule.cost_map
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 202917b..228c76f 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -112,14 +112,15 @@
 }
 
 
-def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
-    if ifm_shape == []:
+def ifm_ifm2_correct_order(ifm_shape: Shape4D, ifm2_shape: Shape4D) -> bool:
+
+    if ifm_shape is None:
         # Scalar needs to be in IFM2
         return False
-    if ifm2_shape == []:
+    if ifm2_shape is None:
         return True
 
-    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
+    for ifm, ifm2 in zip(ifm_shape.as_list(), ifm2_shape.as_list()):
         if ifm != ifm2 and ifm == 1:
             # Broadcasted FM needs to be in IFM2
             return False
@@ -553,8 +554,8 @@
     npu_op = NpuElementWiseOperation(elemwise_op)
 
     if elemwise_op not in UNARY_ELEMWISE_OPS:
-        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
-        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
+        ifm_shape = None if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0]
+        ifm2_shape = None if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1]
         if cmd.reversed_operands:
             assert ifm_ifm2_correct_order(ifm_shape, ifm2_shape)
             npu_op.reversed_operands = True
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index aaabddb..c4176d9 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -242,8 +242,8 @@
     if ifm2 is None:
         ofm_shape = ifm_shape
     else:
-        in_shape = [] if ifm.shape == [] else ifm_shape.as_list()
-        in2_shape = [] if ifm2.shape == [] else ifm2_shape.as_list()
+        in_shape = None if ifm.shape == [] else ifm_shape
+        in2_shape = None if ifm2.shape == [] else ifm2_shape
         ofm_shape = ifm_shape if ifm_ifm2_correct_order(in_shape, in2_shape) else ifm2_shape
 
     ofm = Tensor(ofm_shape.as_list(), dtype, f"{op.name}_tens0")
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 208b121..79cd642 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -227,7 +227,7 @@
         # Perform an IFM swap for certain binary elementwise operators
         # in order to enable cascading, if the SchedOp conforms to
         # Elementwise cascading rules.
-        if self.op_type.is_binary_elementwise_op() and CascadeBuilder.element_wise_cascading_conformity(self):
+        if self.op_type.is_binary_elementwise_op() and CascadeBuilder.elementwise_cascading_conformity(self):
             ifm1 = ps.ifm_tensor
             ifm2 = ps.ifm2_tensor
             ofm = ps.ofm_tensor