MLBEDSW-6261: Elementwise cascading

Enable elementwise cascading for binary and single-variable-IFM operators.

Signed-off-by: erik.andersson@arm.com <erik.andersson@arm.com>
Change-Id: I1c0867875fdc5c4980224fb570185c11e719d5cd
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index e7105e2..09c36b9 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -18,12 +18,12 @@
 # Groups Operators in a schedule together to form Cascades.
 from .numeric_util import round_up
 from .operation import NpuBlockType
+from .operation import Op
 from .shape4d import Shape4D
 
 non_cascadable_blocks = (
     NpuBlockType.Default,
     NpuBlockType.VectorProduct,
-    NpuBlockType.ElementWise,
     NpuBlockType.ReduceSum,
 )
 
@@ -89,11 +89,13 @@
 
     def _is_cascadable(self, sched_op, cost) -> bool:
         """Checks if 'sched_op' can be cascaded"""
+
         return (
             sched_op.op_type.npu_block_type not in non_cascadable_blocks
             and cost.stripe.height < sched_op.ofm.shape.height
             and sched_op.parent_op.read_offsets[0] is None
             and sched_op.parent_op.read_offsets[1] is None
+            and self.element_wise_cascading_conformity(sched_op)
         )
 
     def _estimate_sram_usage(self, sched_op, cost) -> int:
@@ -115,6 +117,24 @@
 
         return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)
 
+    @staticmethod
+    def element_wise_cascading_conformity(sched_op):
+        """Check the inputs of the op to see if it's a candidate for cascading."""
+        # Cascading sub-operators of Softmax results in a crash when handling Sub and RescaleAdd ops
+
+        ifm = sched_op.parent_op.ifm
+        ifm2 = sched_op.parent_op.ifm2
+
+        if sched_op.op_type in [Op.RescaleAdd]:
+            return False
+
+        if sched_op.parent_op.type.is_binary_elementwise_op() and ifm and ifm2:
+            # We cannot rule out cascadability if at least one IFM is constant
+            return Op.Const in (ifm.ops[0], ifm2.ops[0])
+        else:
+            # Either an IFM is missing (single-IFM op) or it is not a binary elementwise op - we cannot rule out cascadability
+            return True
+
     def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
         ref_cost = ref_schedule.cost_map
         fallback_cost = fallback_schedule.cost_map
@@ -260,7 +280,7 @@
                 if not self.spilling:
                     peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
 
-        # Update costing and cascde information for the ref_schedule
+        # Update costing and cascade information for the ref_schedule
         ref_schedule.cost_map = cost
         ref_schedule.cascades = cascade_map
         return ref_schedule
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index 0009f6c..4a41edd 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -34,6 +34,18 @@
         for i in range(len(self.start_coord)):
             assert self.start_coord[i] <= self.end_coord[i]
 
+    @staticmethod
+    def wrap(a, b):
+        """Wrap the coordinates of broadcast tensor boxes in order to
+        prevent out-of-bounds accesses during box creation"""
+        tmp = [0, 0, 0, 0]
+        for i, val in enumerate(a):
+            if int(val) != 0:
+                tmp[i] = a[i]
+                if a[i] >= b[i] and b[i] != 0:
+                    tmp[i] = a[i] % b[i]
+        return Shape4D(tmp)
+
     def transform_with_strides_and_skirt(
         self,
         strides: List[int],
@@ -45,6 +57,7 @@
         split_offset: Optional[Shape4D] = None,
         split_shape: Optional[Shape4D] = None,
         upscaling_factor: int = 1,
+        op_type=None,
     ):
         new_start_coord = list(self.start_coord)
         new_end_coord = list(self.end_coord)
@@ -115,6 +128,15 @@
                 new_end_coord[-3] = new_end_coord[-3] * stride + skirt[2] + (skirt[2] % upscaling_factor)
                 new_end_coord[-3] = max(min(new_end_coord[-3] // upscaling_factor, ifm_shape.height), 1)
 
+        # Wrap the IFMs of broadcasted binary elementwise ops
+        # at the limits of the non-broadcasted volumes
+        # Non-broadcasted ops aren't affected by the wrapping
+        if op_type is not None and op_type.is_binary_elementwise_op():
+            tmp = list(ifm_shape)
+            one = Shape4D(1, 1, 1, 1)
+            new_start_coord = Box.wrap(new_start_coord, tmp)
+            new_end_coord = Box.wrap(Shape4D(list(new_end_coord)) - one, tmp) + one
+
         return Box(new_start_coord, new_end_coord), pad_top, pad_bottom
 
     def make_weight_box(weight_shape, npu_block_type, oc_range_start=None, oc_range_end=None, weights_transposed=False):
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 81c0d5b..9506808 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -134,7 +134,6 @@
         ifm_present = Box([0, 0, 0, 0], [0, 0, 0, 0])
         producer_op = sched_op.ifm.connection.producers[0]
         prev_cmd_gen = generate_high_level_commands_for_sched_op(producer_op, schedule)
-
     ofm_step = op_info.stripe
     for start_height in range(ofm_start.height, ofm_end.height, ofm_step.height):
         end_height = min(start_height + ofm_step.height, ofm_end.height)
@@ -152,7 +151,6 @@
                 ofm_box = Box(ofm_box_start.as_list(), ofm_box_end.as_list())
                 ifm_box = Box([], [])
                 ifm2_box = Box([], [])
-
                 # Calculate IFM input box based on the OFM box
                 if ifm:
                     ifm_box, pad_top, pad_bottom = ofm_box.transform_with_strides_and_skirt(
@@ -165,8 +163,8 @@
                         read_offsets[0],
                         read_shapes[0],
                         upscaling,
+                        op.type,
                     )
-
                 # Calculate IFM2 input box based on the OFM box
                 if ifm2:
                     ifm2_box, pad_top, pad_bottom = ofm_box.transform_with_strides_and_skirt(
@@ -179,6 +177,7 @@
                         read_offsets[1],
                         read_shapes[1],
                         upscaling,
+                        op.type,
                     )
 
                 ifm_required = ifm_box
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 0e17d70..7e989a7 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -183,6 +183,7 @@
         self.uses_scalar = ps.primary_op.ifm2 is not None and (
             ps.primary_op.ifm.shape == [] or ps.primary_op.ifm2.shape == []
         )
+
         self.ifm_ublock = arch.ifm_ublock
 
         self.ifm = SchedulerTensor(
@@ -220,6 +221,31 @@
 
         self.index = 0
 
+        # Perform an IFM swap for certain binary elementwise operators
+        # in order to enable cascading, if the SchedOp conforms to
+        # the elementwise cascading rules.
+        if self.op_type.is_binary_elementwise_op() and CascadeBuilder.element_wise_cascading_conformity(self):
+            ifm1 = ps.ifm_tensor
+            ifm2 = ps.ifm2_tensor
+            ofm = ps.ofm_tensor
+            assert ifm1.elements() > 0
+            assert ifm2.elements() > 0
+
+            if (
+                # The non-constant IFM should be the primary input
+                (ifm1.ops[0].type == Op.Const and ifm2.ops[0].type != Op.Const)
+                # The non-broadcasted IFM should be the primary input
+                or (ifm1.shape != ofm.shape and ifm2.shape == ofm.shape)
+            ):
+                self.ifm, self.ifm2 = self.ifm2, self.ifm
+
+                self.parent_ps.ifm_shapes = self.parent_ps.ifm_shapes[::-1]
+                self.parent_ps.inputs = self.parent_ps.inputs[::-1]
+                self.parent_ps.ifm_tensor, self.parent_ps.ifm2_tensor = (
+                    self.parent_ps.ifm2_tensor,
+                    self.parent_ps.ifm_tensor,
+                )
+
     def add_ifm_connection(self, conn: "Connection"):
         """Add input connection to another SchedulerOperation or Subgraph Input"""
         conn.consumers.append(self)