MLBEDSW-3377: fixup_stridedslice_output may silently change CPU ops

This commit removes the constraint requiring the ranks
of all StridedSlice input tensors (IFM, begin, end and
strides) to match the OFM rank.
The motivation is that this constraint essentially
only checks that the fixup function has already run.
Keeping it forces the fixup function to run before the
supported operators check, which in turn means that
any StridedSlice operator that ends up being placed on
the CPU has still been modified by the fixup function.
With the constraint removed, the fixup function is
moved to after the supported operators check and is
only applied to operators that run on the NPU.
Because the fixup function now runs after the
supported operators check, some cases that can no
longer be reached are removed from the fixup function.

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: I7a82126b7de73bd67873b4e6daf53a6767e33d16
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 32f97d2..e31348b 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -422,19 +422,12 @@
 
 def fixup_stridedslice_output(tens, arch, nng):
     op = tens.ops[0]
-    if op.type == Op.StridedSlice:
+    if op.run_on_npu and op.type == Op.StridedSlice:
         reshape_input_shape = tens.shape
         new_axis_mask = op.attrs["new_axis_mask"]
         shrink_axis_mask = op.attrs["shrink_axis_mask"]
-        ellipsis_mask = op.attrs["ellipsis_mask"]
 
-        if (new_axis_mask != 0 and shrink_axis_mask != 0) or ellipsis_mask != 0:
-            # Not supported, will be put on CPU
-            return tens
-        if shrink_axis_mask == 0 and new_axis_mask == 0:
-            # Equal Rank StridedSlice, no need to insert reshape
-            return tens
-        elif shrink_axis_mask != 0:
+        if shrink_axis_mask != 0:
             n = 0
             axis = 0
             while shrink_axis_mask:
@@ -446,7 +439,6 @@
 
             assert len(tens.shape) == (len(op.inputs[0].shape) - n)
             op.attrs["shrink_axis_mask"] = 0
-
         elif new_axis_mask != 0:
             n = 0
             axis = 0
@@ -1092,7 +1084,7 @@
     for idx, sg in enumerate(nng.subgraphs):
         # rewrite graph pass
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [fixup_stridedslice_output], op_rewrite_list, rewrite_unsupported=False,
+            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
         )
 
     for idx, sg in enumerate(nng.subgraphs):
@@ -1113,7 +1105,7 @@
     for idx, sg in enumerate(nng.subgraphs):
         # combined rewrite graph pass
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [fixup_unpack_output, rewrite_concat, rewrite_split], []
+            nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], []
         )
 
     if verbose_graph:
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 04cda1d..3e649e0 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -219,7 +219,6 @@
         # StridedSlice specific checks:
         self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_input_count)
         self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_inputs_const)
-        self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_tens_size_matches)
         self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_stride_values)
         self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_ellipsis_mask)
         self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_axis_masks)
@@ -728,22 +727,6 @@
         return valid, f"Op has non-constant tensors: {extra}"
 
     @staticmethod
-    def constraint_stridedslice_tens_size_matches(op):
-        "All Input sizes must match OFM size"
-        ifm, begin, end, strides = op.inputs
-        ifm_size = len(ifm.shape)
-        ofm_size = len(op.ofm.shape)
-        begin_size = len(begin.values)
-        end_size = len(end.values)
-        strides_size = len(strides.values)
-        valid = ifm_size == ofm_size == begin_size == end_size == strides_size
-        extra = (
-            f"Op has ofm_size={ofm_size}, ifm_size={ifm_size},"
-            f" begin_size={begin_size}, end_size={end_size} and strides_size={strides_size}"
-        )
-        return valid, extra
-
-    @staticmethod
     def constraint_stridedslice_stride_values(op):
         "All Strides values must be 1"
         strides = op.inputs[3]
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 595ea59..245ebcf 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -486,12 +486,6 @@
     assert not support.is_operator_supported(op)
 
 
-def test_constraint_stridedslice_tens_size_matches():
-    op = create_strided_slice()
-    op.inputs[1].values = [1, 1, 1, 1, 1, 1, 1, 1]
-    assert not support.is_operator_supported(op)
-
-
 def test_constraint_stridedslice_stride_values():
     # Unsupported strides
     op = create_strided_slice()