MLBEDSW-6261: Elementwise cascading

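Enable cascading of elementwise operators that have a single IFM, or two IFMs of which at least one is constant. RescaleAdd is never cascaded, since cascading Softmax sub-operators results in a crash.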

Signed-off-by: erik.andersson@arm.com <erik.andersson@arm.com>
Change-Id: I1c0867875fdc5c4980224fb570185c11e719d5cd
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index e7105e2..09c36b9 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -18,12 +18,12 @@
 # Groups Operators in a schedule together to form Cascades.
 from .numeric_util import round_up
 from .operation import NpuBlockType
+from .operation import Op
 from .shape4d import Shape4D
 
 non_cascadable_blocks = (
     NpuBlockType.Default,
     NpuBlockType.VectorProduct,
-    NpuBlockType.ElementWise,
     NpuBlockType.ReduceSum,
 )
 
@@ -89,11 +89,13 @@
 
     def _is_cascadable(self, sched_op, cost) -> bool:
         """Checks if 'sched_op' can be cascaded"""
+        # All conditions must hold; the last applies op-specific IFM checks for elementwise cascading
         return (
             sched_op.op_type.npu_block_type not in non_cascadable_blocks
             and cost.stripe.height < sched_op.ofm.shape.height
             and sched_op.parent_op.read_offsets[0] is None
             and sched_op.parent_op.read_offsets[1] is None
+            and self.element_wise_cascading_conformity(sched_op)
         )
 
     def _estimate_sram_usage(self, sched_op, cost) -> int:
@@ -115,6 +117,24 @@
 
         return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)
 
+    @staticmethod
+    def element_wise_cascading_conformity(sched_op):
+        """Check the inputs of the op to see if it's a candidate for cascading.
+
+        Returns False for RescaleAdd (cascading sub-operators of Softmax crashes when handling Sub and
+        RescaleAdd ops) and for binary elementwise ops whose two IFMs both lack a Const producer.
+        Returns True otherwise, i.e. when cascadability cannot be ruled out here.
+        """
+        if sched_op.op_type == Op.RescaleAdd:
+            return False
+        ifm, ifm2 = sched_op.parent_op.ifm, sched_op.parent_op.ifm2
+        if not (sched_op.parent_op.type.is_binary_elementwise_op() and ifm and ifm2):
+            # Not a binary elementwise op with two IFMs - we cannot rule out cascadability
+            return True
+        # tens.ops[0] is an Operation, not an Op: compare its .type against Op.Const.
+        # A tensor without a producer op is treated as non-constant instead of raising IndexError.
+        return any(tens.ops and tens.ops[0].type == Op.Const for tens in (ifm, ifm2))
+
     def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
         ref_cost = ref_schedule.cost_map
         fallback_cost = fallback_schedule.cost_map
@@ -260,7 +280,7 @@
                 if not self.spilling:
                     peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
 
-        # Update costing and cascde information for the ref_schedule
+        # Update costing and cascade information for the ref_schedule
         ref_schedule.cost_map = cost
         ref_schedule.cascades = cascade_map
         return ref_schedule