MLBEDSW-8561: Support striding in the H and W dimensions for StridedSlice

Change-Id: Ie6f39d9c4125f7c16d27621de47cd76143c2e636
Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 3af8588..ccbb1f2 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -141,7 +141,7 @@
         if not split_op.run_on_npu:
             return tens
 
-        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
+        inp, outputs, axis, offset_start, offset_end, strides_tens = split_op.get_split_inputs_axis()
 
         tens.ops = []
         new_op = Operation(Op.SplitSliceRead, split_op.name)
@@ -150,8 +150,10 @@
         if None in (offset_end, offset_start):
             read_shape = None
         else:
-            # the read shape is relative to each start offset
-            read_shape = Shape4D([oe - os for oe, os in zip(offset_end, offset_start)])
+            # The read shape is relative to each start offset
+            # Limit read shape to the size of the IFM - offset is not necessarily limited
+            ifm_dims = split_op.ifm_shapes[0].as_list()
+            read_shape = Shape4D([min(oe, ifm_dim) - os for oe, os, ifm_dim in zip(offset_end, offset_start, ifm_dims)])
 
         # For Split the offset cannot be extracted from the tensor so it has to
         # be calculated from the index of the output tensor
@@ -182,6 +184,9 @@
         new_op.set_output_tensor(tens)
         new_op.ifm_shapes.append(Shape4D(inp.shape))
         new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
+        # Set stride multiplier in H/W if a stride tensor is provided
+        s_h, s_w = (strides_tens.values[-3], strides_tens.values[-2]) if strides_tens else (1, 1)
+        new_op.ifm_stride_multiplier[0] = [1, s_h, s_w]  # C/H/W
         DebugDatabase.add_optimised(split_op, new_op)
 
     return tens
@@ -193,18 +198,24 @@
         # Check if it is possible to put the SplitSliceRead on the tensor consumer(s),
         # or if an avgpool need to be inserted
         # Not possible to move:
+        #   - if ifm stride multiplier is larger than one in any dimension
         #   - if consumer is a Transpose op since ifm shape has been reshaped and can not be changed
         #   - if consumer is elementwise and ifm needs to be broadcasted
-        if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all(
-            consumer is not None
-            and consumer.run_on_npu
-            and consumer.type not in memory_only_ops
-            and consumer.original_type != Op.Transpose
-            and check_splitsliceread_to_consumer_shape(op, consumer)
-            and not (
-                consumer.type.is_binary_elementwise_op() and Shape4D.from_list(consumer.ofm.shape) != op.ofm_shapes[0]
+        if (
+            op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
+            and all(s_mul == 1 for s_mul in op.ifm_stride_multiplier[0])
+            and all(
+                consumer is not None
+                and consumer.run_on_npu
+                and consumer.type not in memory_only_ops
+                and consumer.original_type != Op.Transpose
+                and check_splitsliceread_to_consumer_shape(op, consumer)
+                and not (
+                    consumer.type.is_binary_elementwise_op()
+                    and Shape4D.from_list(consumer.ofm.shape) != op.ofm_shapes[0]
+                )
+                for consumer in op.ofm.consumer_list
             )
-            for consumer in op.ofm.consumer_list
         ):
             # SplitSliceRead can be performed by tensor consumer(s)
             for cons_op in list(op.ofm.consumer_list):
@@ -219,6 +230,9 @@
             avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
             avgpool_op.read_offsets[0] = op.read_offsets[0]
             avgpool_op.read_shapes[0] = op.read_shapes[0]
+            if any(s_mul != 1 for s_mul in op.ifm_stride_multiplier[0]):
+                avgpool_op.ifm_stride_multiplier[0] = op.ifm_stride_multiplier[0].copy()
+                avgpool_op.ifm.force_linear_format = True
 
             op.ifm.consumer_list.remove(op)
             DebugDatabase.add_optimised(op, avgpool_op)