blob: df3104341ec5aeb40c703797698ee8ff09810fd1 [file] [log] [blame]
Louis Verhaardfa2f92a2020-09-21 11:56:18 +02001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
# Unit tests for supported_operators
19from ethosu.vela.data_type import DataType
20from ethosu.vela.supported_operators import SupportedOperators
21from ethosu.vela.tensor import create_const_tensor
22from ethosu.vela.tensor import Tensor
23from ethosu.vela.test import testutil
24
# Single SupportedOperators instance shared by every support check in this module.
support = SupportedOperators()
26
27
def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
    """Build a StridedSlice test operation.

    The operation has four inputs: the data tensor plus constant begin, end and
    strides tensors. Every stride is 1 and every mask attribute starts at 0;
    callers tweak the attributes/inputs afterwards.

    Args:
        in_shape: shape of the (non-constant) input tensor.
        out_shape: shape of the output tensor.
        start_offsets: values for the constant "begin" tensor.
        end_offsets: values for the constant "end" tensor; its length also
            sizes the "strides" tensor.

    Returns:
        The operation created by testutil.create_op.
    """
    rank = len(end_offsets)
    # NOTE(review): begin/end are tagged uint8 although callers pass negative
    # offsets (e.g. -3); TFLite normally uses int32 here — confirm this is intended.
    ifm = Tensor(in_shape, DataType.uint8, "in")
    begin = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets)
    end = create_const_tensor("end", [rank], DataType.uint8, end_offsets)
    strides = create_const_tensor("strides", [rank], DataType.uint8, [1] * rank)
    ofm = Tensor(out_shape, DataType.uint8, "out")
    mask_names = ("ellipsis_mask", "new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask")
    attrs = {name: 0 for name in mask_names}
    return testutil.create_op("StridedSlice", [ifm, begin, end, strides], ofm, attrs=attrs)
36
37
def create_strided_slice():
    """Return a StridedSlice operation that SupportedOperators accepts.

    A [1, 10, 10, 10] input is sliced to [1, 5, 5, 10]. The begin_mask /
    end_mask bits are what make the out-of-range begin value (127) and the
    zero end values acceptable (assumes TFLite mask semantics: bit i set means
    the begin/end value for axis i is ignored). The assert guarantees that
    tests deriving from this op start from a supported baseline.
    """
    op = create_strided_slice_op(
        in_shape=[1, 10, 10, 10],
        out_shape=[1, 5, 5, 10],
        start_offsets=[127, 2, 2, 0],
        end_offsets=[0, 7, -3, 0],
    )
    op.attrs.update(begin_mask=0b0001, end_mask=0b1001)
    assert support.is_operator_supported(op)
    return op
45
46
def test_strided_slice():
    """Exercise the StridedSlice support checks of SupportedOperators.

    Each negative scenario starts from a fresh, known-valid operator from
    create_strided_slice(), so a rejection can only be caused by the one
    attribute or input that the scenario modifies.
    """
    # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok
    op = create_strided_slice()
    op.attrs["new_axis_mask"] = 2
    assert support.is_operator_supported(op)
    op = create_strided_slice()
    op.attrs["shrink_axis_mask"] = 3
    assert support.is_operator_supported(op)
    # But setting both to non-zero is not supported
    op.attrs["new_axis_mask"] = 2
    assert not support.is_operator_supported(op)
    # begin values must not be None. Use a fresh op: the one above is already
    # unsupported (both masks set), which would make this assertion vacuous.
    op = create_strided_slice()
    op.inputs[1].values = None
    assert not support.is_operator_supported(op)
    # Unsupported strides
    op = create_strided_slice()
    op.inputs[3].values = [1, 1, 2, 1]
    assert not support.is_operator_supported(op)
    # Wrong number of input tensors
    op = create_strided_slice()
    op.add_input_tensor(op.inputs[0].clone())
    assert not support.is_operator_supported(op)
    # Unsupported ellipsis mask
    op = create_strided_slice()
    op.attrs["ellipsis_mask"] = 1
    assert not support.is_operator_supported(op)
    # Examples where end offset <= begin offset
    op = create_strided_slice()
    op.inputs[1].values = [0, 7, 2, 0]
    assert not support.is_operator_supported(op)
    op = create_strided_slice()
    op.inputs[2].values = [0, 7, 2, 0]
    assert not support.is_operator_supported(op)
    op = create_strided_slice()
    op.attrs["begin_mask"] = 0
    assert not support.is_operator_supported(op)
    op = create_strided_slice()
    op.attrs["end_mask"] = 0
    assert not support.is_operator_supported(op)