MLBEDSW-3035: Updated StridedSlice checks

Updated supported operator checks for StridedSlice:
- allow negative indices in begin/end values
- added more checks on shapes

Change-Id: I3ac76bfa6b313f0e2250f0749f152fb0e3aa033c
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 6bc5a32..252f03b 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -43,6 +43,19 @@
     return op
 
 
+def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True):
+    # For strided slice operator: get start or end offsets
+    offsets = len(input_shape) * [0] if is_begin else input_shape[:]
+    for idx in range(len(input_shape)):
+        # If the i:th bit in the mask is set then the value on offset_tens[i] should be ignored
+        if (offset_mask & (1 << idx)) == 0:
+            offsets[idx] = offset_tens.values[idx]
+            if offsets[idx] < 0:
+                # Convert offset to positive value
+                offsets[idx] += input_shape[idx]
+    return offsets
+
+
 class Operation:
     """Class representing a Neural Network operation. Has a name, a type,
 input and output tensors, as well as an attribute dictionary."""
@@ -309,8 +322,6 @@
             input_tens, begin_tens, end_tens, strides_tens = self.inputs
             outputs = self.outputs
             out_tens = outputs[0]
-            offset_start = [0] * len(outputs[0].shape)
-            offset_end = [0] * len(outputs[0].shape)
 
             # Extract masks
             begin_mask = self.attrs["begin_mask"]
@@ -323,20 +334,8 @@
             # may have the attribute modified and handled in the graph optimization phase.
             assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
             assert len(input_tens.shape) == len(out_tens.shape)
-
-            for idx in range(len(input_tens.shape)):
-                # Check if slicing is needed in this axis
-                if end_tens.values[idx] != input_tens.shape[idx] or (
-                    end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0
-                ):
-                    # If the i:th bit in begin_mask is set then the value on begin[i] should be ignored
-                    if (begin_mask & (1 << idx)) == 0:
-                        offset_start[idx] = begin_tens.values[idx]
-
-                    # If the i:th bit in end_mask is set then the value on end[i] should be ignored
-                    if (end_mask & (1 << idx)) == 0:
-                        offset_end[idx] = end_tens.values[idx]
-
+            offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
+            offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
         elif self.type == "UnpackReshaped":
             # Requires fixup_unpack_output to be called before this point
             input_tens = self.inputs[0]
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 63eb01b..9e9da8c 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -19,6 +19,11 @@
 
 from .data_type import BaseType
 from .data_type import DataType
+from .operation import get_slice_offsets
+
+
+def warn_cpu(op, msg):
+    print("Warning: {} {}, placing on CPU".format(op.type, msg))
 
 
 class SupportedOperators:
@@ -381,17 +386,45 @@
 
     def check_memory_only_restrictions(self, op):
         if op.type == "StridedSlice":
-            # check stride size
-            if len(op.inputs) > 3 and any(stride != 1 for stride in op.inputs[3].values):
+            if len(op.inputs) != 4:
+                warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
                 return False
-            # check "end - begin" doesnt result in any zero or negative elements
-            if any((end - begin) <= 0 for begin, end in zip(op.inputs[1].values, op.inputs[2].values)):
+            input_tens, begin_tens, end_tens, strides_tens = op.inputs
+            if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
+                warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
+                return False
+            if not (
+                len(input_tens.shape)
+                == len(op.outputs[0].shape)
+                == len(begin_tens.values)
+                == len(end_tens.values)
+                == len(strides_tens.values)
+            ):
+                warn_cpu(op, "has input tensors with shapes that are not supported")
+                return False
+            # check stride size
+            if any(stride != 1 for stride in strides_tens.values):
+                warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
                 return False
             # check ellipsis_mask
             if op.attrs["ellipsis_mask"] != 0:
+                warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
                 return False
             # check if both new_axis_mask and shrink_axis_mask have bit set
             if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
+                warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
+                return False
+            # Calculate offset start/end
+            offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
+            offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
+            # check "end - begin" doesn't result in any zero or negative elements
+            if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
+                warn_cpu(
+                    op,
+                    "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
+                        begin_tens.values, end_tens.values
+                    ),
+                )
                 return False
         if op.type == "SplitV":
             # check that maximum one size is set to -1, indicating that size should be inferred
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
new file mode 100644
index 0000000..df31043
--- /dev/null
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -0,0 +1,86 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Unit tests for support_operators
+from ethosu.vela.data_type import DataType
+from ethosu.vela.supported_operators import SupportedOperators
+from ethosu.vela.tensor import create_const_tensor
+from ethosu.vela.tensor import Tensor
+from ethosu.vela.test import testutil
+
+support = SupportedOperators()
+
+
+def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
+    in0 = Tensor(in_shape, DataType.uint8, "in")
+    in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets)
+    in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets)
+    in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1])
+    out = Tensor(out_shape, DataType.uint8, "out")
+    attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
+    return testutil.create_op("StridedSlice", [in0, in1, in2, in3], out, attrs=attrs)
+
+
+def create_strided_slice():
+    # Creates a valid strided slice operator with some valid inputs/outputs
+    op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
+    op.attrs["begin_mask"] = 1
+    op.attrs["end_mask"] = 9
+    assert support.is_operator_supported(op)
+    return op
+
+
+def test_strided_slice():
+    # Tests support for StridedSlice operator
+    op = create_strided_slice()
+    # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok
+    op.attrs["new_axis_mask"] = 2
+    assert support.is_operator_supported(op)
+    op = create_strided_slice()
+    op.attrs["shrink_axis_mask"] = 3
+    assert support.is_operator_supported(op)
+    # But setting both to non-zero is not supported
+    op.attrs["new_axis_mask"] = 2
+    assert not support.is_operator_supported(op)
+    # begin values must not be None
+    op.inputs[1].values = None
+    assert not support.is_operator_supported(op)
+    # Unsupported strides
+    op = create_strided_slice()
+    op.inputs[3].values = [1, 1, 2, 1]
+    assert not support.is_operator_supported(op)
+    # Wrong number of input tensors
+    op = create_strided_slice()
+    op.add_input_tensor(op.inputs[0].clone())
+    assert not support.is_operator_supported(op)
+    # Unsupported ellipsis mask
+    op = create_strided_slice()
+    op.attrs["ellipsis_mask"] = 1
+    assert not support.is_operator_supported(op)
+    # Examples where end offset <= begin offset
+    op = create_strided_slice()
+    op.inputs[1].values = [0, 7, 2, 0]
+    assert not support.is_operator_supported(op)
+    op = create_strided_slice()
+    op.inputs[2].values = [0, 7, 2, 0]
+    assert not support.is_operator_supported(op)
+    op = create_strided_slice()
+    op.attrs["begin_mask"] = 0
+    assert not support.is_operator_supported(op)
+    op = create_strided_slice()
+    op.attrs["end_mask"] = 0
+    assert not support.is_operator_supported(op)
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index d4ae97b..13b6bf4 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -56,6 +56,14 @@
     return op
 
 
+def create_op(op_type, inputs, output, attrs=dict()):
+    op = Operation(op_type, output.name + "_op")
+    op.inputs = inputs
+    op.outputs = [output]
+    op.attrs = attrs
+    return op
+
+
 def create_subgraph(op_list):
     # Creates subgraph using the given list of operations
     sg = Subgraph()