MLBEDSW-3035: Updated StridedSlice checks

Updated supported operator checks for StridedSlice:
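- allowed negative indices in begin/end values
- added more checks on shapes: the operator must have exactly 4 inputs,
  begin/end/strides must be constant, and the input rank, output rank and
  the lengths of begin/end/strides must all match
- added a warning that states why an operator is placed on the CPU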

Change-Id: I3ac76bfa6b313f0e2250f0749f152fb0e3aa033c
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 63eb01b..9e9da8c 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -19,6 +19,11 @@
 
 from .data_type import BaseType
 from .data_type import DataType
+from .operation import get_slice_offsets
+
+
+def warn_cpu(op, msg):
+    print("Warning: {} {}, placing on CPU".format(op.type, msg))
 
 
 class SupportedOperators:
@@ -381,17 +386,45 @@
 
     def check_memory_only_restrictions(self, op):
         if op.type == "StridedSlice":
-            # check stride size
-            if len(op.inputs) > 3 and any(stride != 1 for stride in op.inputs[3].values):
+            if len(op.inputs) != 4:
+                warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
                 return False
-            # check "end - begin" doesnt result in any zero or negative elements
-            if any((end - begin) <= 0 for begin, end in zip(op.inputs[1].values, op.inputs[2].values)):
+            input_tens, begin_tens, end_tens, strides_tens = op.inputs
+            if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
+                warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
+                return False
+            if not (
+                len(input_tens.shape)
+                == len(op.outputs[0].shape)
+                == len(begin_tens.values)
+                == len(end_tens.values)
+                == len(strides_tens.values)
+            ):
+                warn_cpu(op, "has input tensors with shapes that are not supported")
+                return False
+            # check stride size
+            if any(stride != 1 for stride in strides_tens.values):
+                warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
                 return False
             # check ellipsis_mask
             if op.attrs["ellipsis_mask"] != 0:
+                warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
                 return False
             # check if both new_axis_mask and shrink_axis_mask have bit set
             if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
+                warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
+                return False
+            # Calculate offset start/end
+            offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
+            offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
+            # check "end - begin" doesn't result in any zero or negative elements
+            if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
+                warn_cpu(
+                    op,
+                    "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
+                        begin_tens.values, end_tens.values
+                    ),
+                )
                 return False
         if op.type == "SplitV":
             # check that maximum one size is set to -1, indicating that size should be inferred