MLBEDSW-6343: Remove op_index constraint

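Remove the op_index constraint from fixup_strided_conv so that every
strided Conv2D is considered for stride optimisation, not only the first
operator in the network. Force linear IFM format only for Conv2Ds whose
stride is actually optimised. Also run fixup_strided_conv later in the
pass list, after fixup_dilation_gt2.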

Change-Id: Idef3508ab074ea9abeacac030eaaa15a00ad1211
Signed-off-by: Raul Farkas <raul.farkas@arm.com>
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 44f5d6a..518b6db 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -942,19 +942,21 @@
     return op
 
 
-def fixup_strided_conv(op, arch, nng):
+def fixup_strided_conv(op: Operation, arch, nng):
+    """Optimize or fixup strided Conv2DBias
+    Optimization:
+        Reduce, when possible, the Conv2DBias stride from 2 to 1 by re-shaping
+        both IFM and filter.
+
+    Fixup:
+        Introduce software support for Conv2DBias with stride_width = 4 by
+        reducing it to 1 when possible by re-shaping both IFM and filter.
+    """
     if op.type != Op.Conv2DBias:
         return op
     stride_x, stride_y = op.get_kernel_stride()
     weight_tensor = op.weights
     ifm_shape = op.ifm_shapes[0]
-
-    # Do not optimize if op is not the first in the network and stride is
-    # supported by the hardware
-    if op.op_index != 0 and stride_x < 4:
-        return op
-    op.ifm.needs_linear_format = True
-
     if (
         (stride_x == 2 or stride_x == 4)
         and ifm_shape.depth <= 4
@@ -1004,6 +1006,7 @@
         stride_x = 1
         op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
 
+        op.ifm.force_linear_format = True
     return op
 
 
@@ -2125,7 +2128,6 @@
         convert_prelu,
         convert_mul_max_to_abs_or_lrelu,
         convert_lrelu,
-        fixup_strided_conv,
         convert_hardswish_to_lut,
         rewrite_fully_connected_input,
         convert_batched_fc_shape,
@@ -2139,6 +2141,7 @@
         convert_tanh_sigmoid_to_lut,
         replace_pad_by_hw_pad,
         fixup_dilation_gt2,
+        fixup_strided_conv,
     ]
 
     for idx, sg in enumerate(nng.subgraphs):