# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Unit tests for supported_operators
import numpy as np

from ethosu.vela.data_type import DataType
from ethosu.vela.operation import ActivationFunction
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
from ethosu.vela.supported_operators import SupportedOperators
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import QuantizationParameters
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil

support = SupportedOperators()
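# Each test below builds a minimal operator with the testutil helpers and
# checks whether SupportedOperators would accept it for placement on the NPU.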


def test_constraint_tens_no_dynamic():
    # Tensors cannot be dynamic (no shape, not a scalar)
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [])
    assert not support.is_operator_supported(op)


def test_constraint_tens_defined_shape():
    # Tensor shapes cannot contain None
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, None, 8], [1, 8, 8, 8])
    assert not support.is_operator_supported(op)


def test_constraint_tens_output_scalar():
    # Scalar output is not allowed at all:
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [])
    op.ofm.values = 0.5
    assert not support.is_operator_supported(op)


def test_constraint_tens_input_scalar():
    # Shapeless input is allowed if it is of a certain type:
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # Invalid shapeless input due to op type:
    op = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8])
    op.ifm.values = 0.5
    assert not support.is_operator_supported(op)


def test_constraint_tens_shape_size():
    # Tensors cannot be > 4D
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False)
    assert not support.is_operator_supported(op)


def test_constraint_tens_dtype():
    # Tensors can only be of type uint8, int8, int16 and int32
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.float32)
    assert not support.is_operator_supported(op)


def test_constraint_tens_int32_ops():
    # For int32, only select op types are allowed:
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert not support.is_operator_supported(op)


def test_constraint_tens_dimension():
    # Tensor dimensions can only be in the inclusive range 1-65535
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 0], [1, 8, 8, 65536])
    assert not support.is_operator_supported(op)


def test_constraint_tens_quant_none_check():
    # Tensors must have quantization parameters
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm2_quant=None)
    assert not support.is_operator_supported(op)


def test_constraint_tens_quant_scale():
    # Quantization scale cannot be infinite
    qp = QuantizationParameters()
    qp.zero_point = 0
    qp.scale_f32 = np.inf
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
    assert not support.is_operator_supported(op)


def test_constraint_tens_quant_per_axis_not_supp():
    # Quantization scale cannot be array-valued for elemwise ops
    qp = QuantizationParameters()
    qp.zero_point = np.zeros((1, 3))
    qp.scale_f32 = np.ones((1, 3))
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
    assert not support.is_operator_supported(op)


def test_constraint_tens_quant_per_axis_is_supp():
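    # Per-axis quantization of the bias tensor is accepted for convolution ops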
    op = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 1, 1, 3], [1, 1, 1, 3], weights_shape=[1, 1, 1, 3], bias_shape=[1, 1, 1, 3]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert support.is_operator_supported(op)
    qp = QuantizationParameters()
    qp.zero_point = np.zeros((1, 3))
    qp.scale_f32 = np.ones((1, 3))
    op.bias.quantization = qp
    assert support.is_operator_supported(op)


def test_constraint_fc_output_2d_not_supp():
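    # FullyConnected output shapes that cannot be mapped to 2D (4D, 3D and 1D here) are rejected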
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1])
    assert not support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1])
    assert not support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1])
    assert not support.is_operator_supported(op)


def test_constraint_fc_output_2d_is_supp():
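    # FullyConnected outputs that map to 2D are accepted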
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
    assert support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024])
    assert support.is_operator_supported(op)


def test_constraint_faf():
    # Fused activation functions, if set, must be a valid op type
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
    op.activation = ActivationFunction(Op.Conv2D)
    assert not support.is_operator_supported(op)


def test_constraint_faf_ofm_dtype():
    # If fused activation function is present, OFM must be 8 or 16 bit
    shp = [1, 8, 8, 8]
    for dtype in [DataType.int8, DataType.uint8, DataType.int16, DataType.int32]:
        op = testutil.create_elemwise_op(Op.Add, "op", shp, shp, shp, datatype=dtype)
        op.activation = ActivationFunction(Op.Relu)
        expected = dtype.size_in_bytes() <= 2
        assert support.is_operator_supported(op) == expected, f"Data type: {dtype}"


def test_constraint_conv_pass():
    # First test a simple conv passes
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert support.is_operator_supported(op)


def test_constraint_stride_type():
    # Stride width and height must be integer types
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 1.5, "stride_h": "1"}
    assert not support.is_operator_supported(op)


def test_constraint_stride_range():
    # Stride width and height must lie within a certain range
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 0, "stride_h": 20}
    assert not support.is_operator_supported(op)


def test_constraint_dilation_type():
    # Dilation width and height must be integer types
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"}
    assert not support.is_operator_supported(op)


def test_constraint_dilation_range():
    # Dilation width and height must lie within a certain range
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 0, "dilation_h_factor": 20}
    assert not support.is_operator_supported(op)


def test_constraint_dilated_height_range():
    # Dilated kernel height must lie within a certain range
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[65, 64, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)


def test_constraint_dilated_product_range():
    # Dilated kernel width x height must lie within a certain range
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[64, 65, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)


def test_constraint_weights_type():
    # Weight tensor must be 8-bit
    op = testutil.create_op_with_quant_tensors(
        Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1], datatype=DataType.int16
    )
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)


def test_constraint_weights_const():
    # Weight tensors must be constant
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    weights = Tensor([64, 64, 1, 1], DataType.uint8, "weights")
    weights.quantization = testutil.default_quant_params()
    op.add_input_tensor(weights)
    assert not support.is_operator_supported(op)


def test_constraint_weights_limit():
    # Sum of weights has a limit
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    op.weights.quantization.zero_point = np.array([[[[(127 * 65536) + 1]]]])
    assert not support.is_operator_supported(op)


def test_constraint_bias_type():
    # Bias must have a certain datatype
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    bias = Tensor([1, 8, 8, 8], DataType.uint8, "bias")
    op.add_input_tensor(bias)
    assert not support.is_operator_supported(op)


def test_constraint_bias_40bit():
    # Bias must not exceed 40-bit
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    bias = Tensor([1, 1, 1, 1], DataType.int64, "bias")
    bias.quant_values = np.array([0x01FF_FFFF_FFFF])
    op.add_input_tensor(bias)
    assert not support.is_operator_supported(op)


def test_constraint_batch_size():
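    # The batch size of the IFM must be 1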
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [2, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)


def test_constraint_quant_scale_inf():
    # The IFM scale divided by the OFM scale must not be infinite
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
    op.ifm.quantization.scale_f32 = np.float32(1e9)
    op.ofm.quantization.scale_f32 = np.float32(1e-35)
    assert not support.is_operator_supported(op)


def test_constraint_ofm_scale_too_small():
    # Tests handling of OFM scale < 1e-38
    shp = [1, 10, 20, 16]
    op = testutil.create_elemwise_op(Op.Mul, "mul", shp, shp, shp, ofm_quant=testutil.default_quant_params())
    assert support.is_operator_supported(op)
    op.ofm.quantization.scale_f32 = 1e-43
    assert not support.is_operator_supported(op)


def test_constraint_depth_multiplier():
    # Valid. Depth multiplier is 1 so there are no further constraints
    op = testutil.create_op_with_quant_tensors(
        Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 2], weights_shape=[1, 1, 1, 1]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 1}
    assert support.is_operator_supported(op)
    # Invalid. Depth multiplier does not equal the ofm channel count
    op = testutil.create_op_with_quant_tensors(
        Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 2}
    assert not support.is_operator_supported(op)
    # Valid. Depth multiplier is equal to the ofm channel count
    op = testutil.create_op_with_quant_tensors(
        Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 2], weights_shape=[1, 1, 1, 1]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 2}
    assert support.is_operator_supported(op)


def test_constraint_tconv_stride():
    # Strides must be 2
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 2, 2, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1, "padding": Padding.SAME}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert not support.is_operator_supported(op)


def test_constraint_tconv_same():
    # Valid
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 2, 2, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.SAME}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert support.is_operator_supported(op)
    # Invalid
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.SAME}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert not support.is_operator_supported(op)


def test_constraint_tconv_valid():
    # Valid
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[4, 4, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.VALID}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert support.is_operator_supported(op)
    # Invalid
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[2, 2, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.VALID}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert not support.is_operator_supported(op)


def test_constraint_matching_in_out_types():
    # Valid
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME}
    assert support.is_operator_supported(op)
    # Invalid. Datatypes for ifm and ofm must match (default is uint8)
    op.ifm.dtype = DataType.int8
    assert not support.is_operator_supported(op)


def test_constraint_filter_type():
    # Filter width/height must be integers
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME}
    assert not support.is_operator_supported(op)


def test_constraint_filter_range():
    # Avg pool restrictions are dependent on padding:
    # SAME padding restricts both W and H to a max of 8
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 20, "filter_height": 20, "padding": Padding.SAME}
    assert not support.is_operator_supported(op)
    # VALID padding limits are much larger
    op.attrs["padding"] = Padding.VALID
    assert support.is_operator_supported(op)


def test_constraint_filter_height_range_valid_pad():
    # Avg pool restrictions are dependent on padding:
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": Padding.VALID}
    assert support.is_operator_supported(op)
    # VALID padding restricts the filter height to a max of 256
    op.attrs["filter_height"] = 257
    assert not support.is_operator_supported(op)


def test_constraint_filter_product_height_range_valid_pad():
    # Avg pool restrictions are dependent on padding:
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": Padding.VALID}
    assert support.is_operator_supported(op)
    # VALID padding restricts filter W x H to a max of 256x256
    op.attrs["filter_width"] = 257
    assert not support.is_operator_supported(op)


def test_constraint_filter_height_range():
    # Max pool restrictions aren't dependent on padding
    op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": Padding.SAME}
    assert support.is_operator_supported(op)
    # Filter height is restricted to a max of 256
    op.attrs["filter_height"] = 257
    assert not support.is_operator_supported(op)
    # It doesn't matter if padding is SAME or VALID
    op.attrs["padding"] = Padding.VALID
    assert not support.is_operator_supported(op)


def test_constraint_filter_product_height_range():
    # Max pool restrictions aren't dependent on padding
    op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": Padding.SAME}
    assert support.is_operator_supported(op)
    # Filter W x H is restricted to a max of 256x256
    op.attrs["filter_width"] = 257
    assert not support.is_operator_supported(op)
    # It doesn't matter if padding is SAME or VALID
    op.attrs["padding"] = Padding.VALID
    assert not support.is_operator_supported(op)


def test_constraint_resize():
    # IFM W and H == 1
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 1, 1, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # IFM == OFM
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 8, 8, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # IFM x2 == OFM ; align_corners = False
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # IFM x2 -1 == OFM ; align_corners = True
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 7, 7, 8])
    op.attrs["align_corners"] = True
    assert support.is_operator_supported(op)
    # Invalid cases
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 20, 20, 8])
    assert not support.is_operator_supported(op)
    op.attrs["align_corners"] = True
    assert not support.is_operator_supported(op)


def test_constraint_matching_shapes():
    # Softmax requires the ifm and ofm shapes to match
    op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4])
    assert not support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
    assert support.is_operator_supported(op)


def test_constraint_beta_value_range():
    # beta cannot be negative
    op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
    op.attrs["beta"] = -1.0
    assert not support.is_operator_supported(op)
    op.attrs["beta"] = 0.0
    assert support.is_operator_supported(op)


def test_constraint_splitv_inferred():
    # SplitV requires a maximum of one inferred shape (-1)
    qp = testutil.default_quant_params()
    op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
    op.add_input_tensor(sizes)
    assert not support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
    op.add_input_tensor(sizes)
    assert support.is_operator_supported(op)


def test_constraint_concat_pass():
    # A working concat
    op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 3
    assert support.is_operator_supported(op)


def test_constraint_axis_exists():
    # Missing axis attribute
    op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    assert not support.is_operator_supported(op)


def test_constraint_axis_valid():
    # Invalid axis attribute
    op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 7
    assert not support.is_operator_supported(op)


def test_constraint_matching_dimensionality():
    # Mismatching dimensionality: 4D+2D=4D
    op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 3
    assert not support.is_operator_supported(op)


def test_constraint_valid_dimensions():
    # Mismatching dimension value:
    # ifm2 has w and h as 2, which is not the axis to concat and doesn't match ifm1 or ofm
    op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 2, 2, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 3
    assert not support.is_operator_supported(op)


def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
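    # Builds a StridedSlice op with constant begin/end/strides inputs; all strides are 1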
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, DataType.uint8, "in")
    in0.quantization = qp
    in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
    in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
    in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
    out = Tensor(out_shape, DataType.uint8, "out")
    out.quantization = qp
    attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
    return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)


def create_pad_op(
    in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32,
):
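    # Builds a Pad op with a constant padding tensor; quantization uses testutil defaults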
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, in_dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
    out = Tensor(out_shape, out_dtype, "out")
    out.quantization = qp.clone()
    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    return op


def test_constraint_pad_input_count():
    # Pad requires exactly two input tensors; adding a third makes it unsupported
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
    assert support.is_operator_supported(op)
    op.add_input_tensor(op.inputs[0].clone())
    assert not support.is_operator_supported(op)


def test_constraint_padded_dimensions():
    # Incorrect padding dimensions; only width and height can be padded
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [1, 1], [0, 0]])
    assert not support.is_operator_supported(op)
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [0, 0]])
    assert support.is_operator_supported(op)
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [0, 1]])
    assert not support.is_operator_supported(op)


def test_constraint_pad_shape():
    # The padding tensor must have shape (3, 2) or (4, 2)
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [0, 0]])
    assert support.is_operator_supported(op)
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]])
    assert not support.is_operator_supported(op)


def test_constraint_pad_none():
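    # An empty padding tensor is rejected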
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[])
    assert not support.is_operator_supported(op)


def test_constraint_pad_dtype():
    # The padding tensor dtype must be int32 or int64
    op = create_pad_op(
        in_shape=[1, 1, 1, 1],
        out_shape=[1, 3, 3, 1],
        padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]],
        pad_dtype=DataType.int16,
    )
    assert not support.is_operator_supported(op)


def create_strided_slice():
    # Creates a valid strided slice operator with some valid inputs/outputs
    op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
    op.attrs["begin_mask"] = 1
    op.attrs["end_mask"] = 9
    assert support.is_operator_supported(op)
    return op


def test_constraint_stridedslice_input_count():
    # Wrong number of input tensors
    op = create_strided_slice()
    op.add_input_tensor(op.inputs[0].clone())
    assert not support.is_operator_supported(op)


def test_constraint_stridedslice_inputs_const():
    # begin, end, stride values must not be None
    op = create_strided_slice()
    op.inputs[1].values = None
    assert not support.is_operator_supported(op)
    op = create_strided_slice()
    op.inputs[2].values = None
    assert not support.is_operator_supported(op)
    op = create_strided_slice()
    op.inputs[3].values = None
    assert not support.is_operator_supported(op)


def test_constraint_stridedslice_stride_values():
    # Unsupported strides
    op = create_strided_slice()
    op.inputs[3].values = [1, 1, 2, 1]
    assert not support.is_operator_supported(op)


def test_constraint_ellipsis_mask():
    # Unsupported ellipsis mask
    op = create_strided_slice()
    op.attrs["ellipsis_mask"] = 1
    assert not support.is_operator_supported(op)


def test_constraint_axis_masks():
    op = create_strided_slice()
    # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok
    op.attrs["new_axis_mask"] = 2
    assert support.is_operator_supported(op)
    op = create_strided_slice()
    op.attrs["shrink_axis_mask"] = 3
    assert support.is_operator_supported(op)
    # But setting both to non-zero is not supported
    op.attrs["new_axis_mask"] = 2
    assert not support.is_operator_supported(op)


def test_constraint_slice_ranges():
    # Examples where end offset <= begin offset
    op = create_strided_slice()
    op.inputs[1].values = [0, 7, 2, 0]
    assert not support.is_operator_supported(op)
    op = create_strided_slice()
    op.inputs[2].values = [0, 7, 2, 0]
    assert not support.is_operator_supported(op)
    op = create_strided_slice()
    op.attrs["begin_mask"] = 0
    assert not support.is_operator_supported(op)
    op = create_strided_slice()
    op.attrs["end_mask"] = 0
    assert not support.is_operator_supported(op)


def test_constraint_matching_inputs_types():
    # Input data types must match (default is uint8)
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    op.ifm2.dtype = DataType.int8
    assert not support.is_operator_supported(op)


def test_constraint_matching_signed():
    # Signed inputs require the output to also be signed
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
    op.ofm.dtype = DataType.uint8
    assert not support.is_operator_supported(op)


def test_constraint_unsigned_valid():
    # Unsigned inputs require the output to be either:
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    # the same (default uint8)
    assert support.is_operator_supported(op)
    op.ofm.dtype = DataType.int8
    assert not support.is_operator_supported(op)
    op.ofm.dtype = DataType.int16
    assert not support.is_operator_supported(op)
    # or int32
    op.ofm.dtype = DataType.int32
    assert support.is_operator_supported(op)


def test_constraint_inputs_int32():
    # Both inputs must be of type int32
    op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    op.ifm2.dtype = DataType.int16
    assert not support.is_operator_supported(op)


def test_constraint_output_int32():
    # Output must be of type int32
    op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    op.ofm.dtype = DataType.int16
    assert not support.is_operator_supported(op)


def test_constraint_matching_quantization_parameters():
    qp = QuantizationParameters()
    qp.scale_f32 = np.float32(1.5)
    qp.zero_point = 128
    # Valid - all matching (uses default quant params)
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # Invalid - ifm mismatches ofm
    op.ifm.quantization = qp
    assert not support.is_operator_supported(op)
    # Invalid - ifm2 mismatches ofm
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    op.ifm2.quantization = qp
    assert not support.is_operator_supported(op)
    # Invalid - both ifm and ifm2 mismatch ofm
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    op.ifm.quantization = qp
    op.ifm2.quantization = qp
    assert not support.is_operator_supported(op)
    # Valid - all matching
    op.ofm.quantization = qp
    assert support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], None, [1, 8, 8, 8])
    assert support.is_operator_supported(op)


def test_constraint_elemwise_batch_size():
    # BINARY CASE
    # Batch can be >1 if dims is <=2D
    op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [2, 2], [2, 2])
    assert support.is_operator_supported(op)
    # For dims >2D, batch must be 1
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 2], [1, 2, 2], [1, 2, 2])
    assert support.is_operator_supported(op)
    # Invalid case
    op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2], [2, 2, 2], [2, 2, 2])
    assert not support.is_operator_supported(op)

    # UNARY CASE
    # Batch can be >1 if dims is <=2D
    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    # For dims >2D, batch must be 1
    op = testutil.create_elemwise_op(Op.CLZ, "op", [1, 2, 2], None, [1, 2, 2], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    # Invalid case
    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2], None, [2, 2, 2], datatype=DataType.int32)
    assert not support.is_operator_supported(op)


def test_constraint_matching_either_shapes():
    # BINARY CASE
    # At least one ifm shape must match the ofm's shape
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4])
    assert support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4])
    assert support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16])
    assert not support.is_operator_supported(op)

    # UNARY CASE
    # No second input, so this is treated the same as requiring the ifm shape to match the ofm shape
    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32)
    assert not support.is_operator_supported(op)


def test_constraint_broadcast_shapes():
    # BINARY CASE
    # Broadcasting from a dim of 1 is allowed, here at one rank index
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4], [1, 2, 4], [1, 2, 4])
    assert support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 1, 4], [1, 2, 4])
    assert support.is_operator_supported(op)
    # Broadcasting from a dim of 1 is allowed, here at three rank indexes
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 1, 1], [1, 4, 8, 16], [1, 4, 8, 16])
    assert support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 8, 16], [1, 1, 1, 1], [1, 4, 8, 16])
    assert support.is_operator_supported(op)
    # A broadcast dim that is not 1 is invalid
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 4, 4], [1, 4, 4])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 4], [1, 2, 4], [1, 4, 4])
    assert not support.is_operator_supported(op)
    # Each OFM dim must equal the largest corresponding ifm/ifm2 dim
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [1, 4])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 1, 16])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 1, 16])
    assert not support.is_operator_supported(op)


def test_constraint_alpha_valid():
    # Alpha cannot be negative
    op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])
    op.attrs["alpha"] = 0
    assert support.is_operator_supported(op)
    op.attrs["alpha"] = -1
    assert not support.is_operator_supported(op)


def test_constraint_hardswish_dtype():
    # HardSwish operator dtype should be int8 or uint8, and input dtype must match output
    # UINT8
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # INT8
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
    assert support.is_operator_supported(op)

    # Invalid
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
    assert not support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16)
    assert not support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert not support.is_operator_supported(op)

    in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
    out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
    op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
    assert not support.is_operator_supported(op)


def test_constraint_keep_dims_ifm_ofm():
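    # keep_num_dims is rejected here since the 2D OFM cannot keep the 4D IFM's dimensions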
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
    op.attrs["keep_num_dims"] = True
    assert not support.is_operator_supported(op)
    op.attrs["keep_num_dims"] = False
    assert support.is_operator_supported(op)


def create_mean(input_shape, output_shape, indices, datatype, attrs):
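    # Builds a Mean op with a constant int32 tensor of reduction indices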
    ifm = Tensor(input_shape, datatype, "in")
    ifm.quantization = testutil.default_quant_params()
    indices = create_const_tensor("indices", [len(indices)], DataType.int32, indices, np.uint8)
    ofm = Tensor(output_shape, datatype, "out")
    ofm.quantization = testutil.default_quant_params()
    op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
    return op


def test_mean_dtype():
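    # int8 is supported; switching the ifm/ofm to int16 makes the op unsupported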
    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert support.is_operator_supported(op)
    op.ifm.dtype = DataType.int16
    op.ofm.dtype = DataType.int16
    assert not support.is_operator_supported(op)


def test_mean_axis():
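    # Reducing over axis 1 alone is not a supported axis combination here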
    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
    assert not support.is_operator_supported(op)


def test_mean_hw_product():
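    # The product of IFM height and width is limited: 64x64 passes, 65x64 does not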
    op = create_mean([1, 64, 64, 16], [1, 16], [1, 2], DataType.uint8, {})
    assert support.is_operator_supported(op)
    op = create_mean([1, 65, 64, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert not support.is_operator_supported(op)


def test_mean_hw_product_int8():
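    # A stricter height x width limit applies for int8 with keep_dims: 16x16 passes, 16x17 does not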
    op = create_mean([1, 16, 16, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert support.is_operator_supported(op)
    op = create_mean([1, 16, 17, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert not support.is_operator_supported(op)


def test_mean_hw_product_avgpool():
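    # 200x200 passes for uint8 without keep_dims, but not for int8 with keep_dims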
    op = create_mean([1, 200, 200, 16], [1, 16], [1, 2], DataType.uint8, {"keep_dims": False})
    assert support.is_operator_supported(op)
    op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert not support.is_operator_supported(op)