blob: 36213b7314441294f821a6fef56e812f0c311ab6 [file] [log] [blame]
Louis Verhaardfa2f92a2020-09-21 11:56:18 +02001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Unit tests for support_operators
Michael McGeagh37ded342020-10-01 15:37:44 +010019import numpy as np
20
Louis Verhaardfa2f92a2020-09-21 11:56:18 +020021from ethosu.vela.data_type import DataType
Louis Verhaarde8a5a782020-11-02 18:04:27 +010022from ethosu.vela.operation import ActivationFunction
Louis Verhaardaee5d752020-09-30 09:01:52 +020023from ethosu.vela.operation import Op
Michael McGeagh16895482020-12-14 15:51:20 +000024from ethosu.vela.operation import Padding
Louis Verhaardfa2f92a2020-09-21 11:56:18 +020025from ethosu.vela.supported_operators import SupportedOperators
26from ethosu.vela.tensor import create_const_tensor
Michael McGeagh37ded342020-10-01 15:37:44 +010027from ethosu.vela.tensor import QuantizationParameters
Louis Verhaardfa2f92a2020-09-21 11:56:18 +020028from ethosu.vela.tensor import Tensor
29from ethosu.vela.test import testutil
30
# Shared checker instance; every test below queries this one object.
support = SupportedOperators()
32
33
def test_constraint_tens_no_dynamic():
    # Tensors cannot be dynamic (no shape, not a scalar)
    # An empty-list OFM shape models a dynamic tensor.
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [])
    assert not support.is_operator_supported(op)


def test_constraint_tens_defined_shape():
    # Tensors cannot have None in them
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, None, 8], [1, 8, 8, 8])
    assert not support.is_operator_supported(op)


def test_constraint_tens_output_scalar():
    # Scalar output is not allowed at all:
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [])
    # Giving the shapeless OFM a scalar value marks it as a constant scalar.
    op.ofm.values = 0.5
    assert not support.is_operator_supported(op)


def test_constraint_tens_input_scalar():
    # Shapeless input is allowed if its of a certain type:
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # Invalid shapeless input due to op type:
    op = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8])
    op.ifm.values = 0.5
    assert not support.is_operator_supported(op)
61
62
def test_constraint_tens_shape_size():
    # Tensors cannot be > 4D
    # set_ifm_ofm_shapes=False keeps the helper from rejecting the 5D shape
    # before the checker under test gets to see it.
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False)
    assert not support.is_operator_supported(op)


def test_constraint_tens_dtype():
    # Tensors can only be of type uint8, int8, int16 and int32
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.float32)
    assert not support.is_operator_supported(op)


def test_constraint_tens_int32_ops():
    # For int32, only select op types are allowed:
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert not support.is_operator_supported(op)


def test_constraint_tens_dimension():
    # Tensors can only have values in the inclusive range of 1-65535
    # 0 is below the minimum and 65536 is above the maximum.
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 0], [1, 8, 8, 65536])
    assert not support.is_operator_supported(op)
87
88
def test_constraint_tens_quant_none_check():
    # Tensors must have quantization parameters
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm2_quant=None)
    assert not support.is_operator_supported(op)


def test_constraint_tens_quant_scale():
    # Quantization scale cannot be infinite
    qp = QuantizationParameters()
    qp.zero_point = 0
    qp.scale_f32 = np.inf
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
    assert not support.is_operator_supported(op)


def test_constraint_tens_quant_per_axis_not_supp():
    # Quantization scale cannot be array-valued for elemwise ops
    qp = QuantizationParameters()
    qp.zero_point = np.zeros((1, 3))
    qp.scale_f32 = np.ones((1, 3))
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
    assert not support.is_operator_supported(op)


def test_constraint_tens_quant_per_axis_is_supp():
    # Per-axis quantization on a convolution bias tensor is allowed.
    op = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 1, 1, 3], [1, 1, 1, 3], weights_shape=[1, 1, 1, 3], bias_shape=[1, 1, 1, 3]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert support.is_operator_supported(op)
    qp = QuantizationParameters()
    qp.zero_point = np.zeros((1, 3))
    qp.scale_f32 = np.ones((1, 3))
    op.bias.quantization = qp
    assert support.is_operator_supported(op)
124
125
def test_constraint_fc_output_2d_not_supp():
    # FullyConnected is rejected when the OFM is not 2D.
    rejected_cases = (
        ([12, 1], [3, 2, 2, 1], [12, 1, 1, 1]),
        ([12, 1, 1, 1], [1, 3, 4], [12, 1, 1, 1]),
        ([1, 1, 1, 1], [1], [1, 1, 1, 1]),
    )
    for ifm_shape, ofm_shape, w_shape in rejected_cases:
        op = testutil.create_op_with_quant_tensors(Op.FullyConnected, ifm_shape, ofm_shape, weights_shape=w_shape)
        assert not support.is_operator_supported(op)


def test_constraint_fc_output_2d_is_supp():
    # FullyConnected with a 2D OFM is accepted.
    accepted_cases = (
        ([4, 8, 8, 4], [32, 32], [4, 8, 8, 4]),
        ([1, 1024], [16, 64], [1, 1024]),
    )
    for ifm_shape, ofm_shape, w_shape in accepted_cases:
        op = testutil.create_op_with_quant_tensors(Op.FullyConnected, ifm_shape, ofm_shape, weights_shape=w_shape)
        assert support.is_operator_supported(op)
140
141
def test_constraint_faf():
    # Fused activation functions, if set, must be a valid op type
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
    # Conv2D is not an activation op, so it is invalid as a fused activation.
    op.activation = ActivationFunction(Op.Conv2D)
    assert not support.is_operator_supported(op)


def test_constraint_conv_pass():
    # First test a simple conv passes
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert support.is_operator_supported(op)
154
155
def test_constraint_stride_type():
    # Stride width and height must be integer types
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
    # Both a float and a string stride should be rejected.
    op.attrs = {"stride_w": 1.5, "stride_h": "1"}
    assert not support.is_operator_supported(op)


def test_constraint_stride_range():
    # Stride width and height must lie within a certain range
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
    # 0 is below and 20 above the supported stride range.
    op.attrs = {"stride_w": 0, "stride_h": 20}
    assert not support.is_operator_supported(op)


def test_constraint_dilation_type():
    # Dilation width and height must be integer types
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"}
    assert not support.is_operator_supported(op)


def test_constraint_dilation_range():
    # Dilation width and height must lie within a certain range
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 0, "dilation_h_factor": 20}
    assert not support.is_operator_supported(op)


def test_constraint_dilated_height_range():
    # Dilated kernel height must lie within a certain range
    # weights_shape is [h, w, in_ch, out_ch]; a 65-high kernel is out of range.
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[65, 64, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)


def test_constraint_dilated_product_range():
    # Dilated kernel width x height must lie within a certain range
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[64, 65, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)
196
197
def test_constraint_weights_type():
    # Weight tensor must be 8-bit
    op = testutil.create_op_with_quant_tensors(
        Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1], datatype=DataType.int16
    )
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)


def test_constraint_weights_const():
    # Weight tensor cannot be non-const tensors
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    # A plain Tensor (no constant values attached) models a non-const weight.
    weights = Tensor([64, 64, 1, 1], DataType.uint8, "weights")
    weights.quantization = testutil.default_quant_params()
    op.add_input_tensor(weights)
    assert not support.is_operator_supported(op)


def test_constraint_weights_limit():
    # Sum of weights has a limit
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    # A huge zero point inflates the weight sum past the limit.
    op.weights.quantization.zero_point = np.array([[[[(127 * 65536) + 1]]]])
    assert not support.is_operator_supported(op)


def test_constraint_bias_type():
    # Bias must have a certain datatype
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    # uint8 is not a valid bias dtype.
    bias = Tensor([1, 8, 8, 8], DataType.uint8, "bias")
    op.add_input_tensor(bias)
    assert not support.is_operator_supported(op)


def test_constraint_bias_40bit():
    # Bias must not exceed 40-bit
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    bias = Tensor([1, 1, 1, 1], DataType.int64, "bias")
    # 0x01FF_FFFF_FFFF is a 41-bit value, one past the 40-bit limit.
    bias.quant_values = np.array([0x01FF_FFFF_FFFF])
    op.add_input_tensor(bias)
    assert not support.is_operator_supported(op)


def test_constraint_batch_size():
    # Batch size must be 1 (IFM here has batch 2).
    op = testutil.create_op_with_quant_tensors(Op.Conv2D, [2, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)
Michael McGeagh65fd9982020-10-20 11:49:28 +0100248
249
def test_constraint_quant_scale_inf():
    # Test handling IFM scale/OFM scale is infinite
    # NOTE(review): presumably the checker computes ifm_scale / ofm_scale,
    # which for 1e9 / 1e-35 overflows float32 to infinity — verify against
    # the constraint in supported_operators.
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
    op.ifm.quantization.scale_f32 = np.float32(1e9)
    op.ofm.quantization.scale_f32 = np.float32(1e-35)
    assert not support.is_operator_supported(op)


def test_constraint_ofm_scale_too_small():
    # Tests handling of OFM scale < 1e-38
    shp = [1, 10, 20, 16]
    op = testutil.create_elemwise_op(Op.Mul, "mul", shp, shp, shp, ofm_quant=testutil.default_quant_params(),)
    assert support.is_operator_supported(op)
    # 1e-43 is subnormal for float32 (< ~1.18e-38 normal minimum).
    op.ofm.quantization.scale_f32 = 1e-43
    assert not support.is_operator_supported(op)
265
266
def test_constraint_depth_multiplier():
    # Valid. Depth multiplier is 1 so no further constraints
    op = testutil.create_op_with_quant_tensors(
        Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 2], weights_shape=[1, 1, 1, 1]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 1}
    assert support.is_operator_supported(op)
    # Invalid. Depth multiplier doesnt equal ofm channel
    op = testutil.create_op_with_quant_tensors(
        Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 2}
    assert not support.is_operator_supported(op)
    # Valid. Depth multiplier is equal to ofm channel
    op = testutil.create_op_with_quant_tensors(
        Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 2], weights_shape=[1, 1, 1, 1]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 2}
    assert support.is_operator_supported(op)
286
287
def test_constraint_tconv_stride():
    # Strides must be 2
    # For Conv2DBackpropInput the first input ([0]) is the output-shape
    # tensor; the real IFM is appended afterwards.
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 2, 2, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1, "padding": Padding.SAME}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert not support.is_operator_supported(op)


def test_constraint_tconv_same():
    # Valid: with SAME padding, OFM dims must be exactly stride * IFM dims.
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 2, 2, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.SAME}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert support.is_operator_supported(op)
    # Invalid: OFM 4x4 is not 2x the 1x1 IFM.
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.SAME}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert not support.is_operator_supported(op)


def test_constraint_tconv_valid():
    # Valid: with VALID padding the kernel size must match the upscaling.
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[4, 4, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.VALID}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert support.is_operator_supported(op)
    # Invalid: a 2x2 kernel cannot produce the 4x4 OFM from a 1x1 IFM.
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[2, 2, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.VALID}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert not support.is_operator_supported(op)
330
331
def test_constraint_matching_in_out_types():
    # Valid
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME}
    assert support.is_operator_supported(op)
    # Invalid. datatypes for ifm and ofm must match (default uint8)
    op.ifm.dtype = DataType.int8
    assert not support.is_operator_supported(op)


def test_constraint_filter_type():
    # Filter width/height must be integers
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME}
    assert not support.is_operator_supported(op)


def test_constraint_filter_range():
    # Avg pool restrictions are dependent on padding:
    # SAME padding restricts both W and H to max 8
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 20, "filter_height": 20, "padding": Padding.SAME}
    assert not support.is_operator_supported(op)
    # VALID padding limits are much larger
    op.attrs["padding"] = Padding.VALID
    assert support.is_operator_supported(op)


def test_constraint_filter_height_range_valid_pad():
    # Avg pool restrictions are dependent on padding:
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": Padding.VALID}
    assert support.is_operator_supported(op)
    # VALID padding restricts to 256 in filter height
    op.attrs["filter_height"] = 257
    assert not support.is_operator_supported(op)


def test_constraint_filter_product_height_range_valid_pad():
    # Avg pool restrictions are dependent on padding:
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": Padding.VALID}
    assert support.is_operator_supported(op)
    # VALID padding restricts filter W x H to 256x256
    op.attrs["filter_width"] = 257
    assert not support.is_operator_supported(op)


def test_constraint_filter_height_range():
    # Max pool restrictions arent dependent on padding
    op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": Padding.SAME}
    assert support.is_operator_supported(op)
    # Restricts to 256 in filter height
    op.attrs["filter_height"] = 257
    assert not support.is_operator_supported(op)
    # Doesnt matter if SAME or VALID
    op.attrs["padding"] = Padding.VALID
    assert not support.is_operator_supported(op)


def test_constraint_filter_product_height_range():
    # Max pool restrictions arent dependent on padding
    op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": Padding.SAME}
    assert support.is_operator_supported(op)
    # Restricts filter W x H to 256x256
    op.attrs["filter_width"] = 257
    assert not support.is_operator_supported(op)
    # Doesnt matter if SAME or VALID
    op.attrs["padding"] = Padding.VALID
    assert not support.is_operator_supported(op)
404
405
def test_constraint_resize():
    # IFM W and H == 1
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 1, 1, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # IFM == OFM
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 8, 8, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # IFM x2 == OFM ; align_corners = False
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # IFM x2 -1 == OFM ; align_corners = True
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 7, 7, 8])
    op.attrs["align_corners"] = True
    assert support.is_operator_supported(op)
    # Invalid cases: 4 -> 20 is not a supported upscaling factor.
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 20, 20, 8])
    assert not support.is_operator_supported(op)
    op.attrs["align_corners"] = True
    assert not support.is_operator_supported(op)


def test_constraint_matching_shapes():
    # Softmax requires the ifm and ofm shapes to match
    op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4])
    assert not support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
    assert support.is_operator_supported(op)
433
434
def test_constraint_beta_value_range():
    # beta must be positive
    op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
    op.attrs["beta"] = -1.0
    assert not support.is_operator_supported(op)
    # Zero is accepted as the boundary value.
    op.attrs["beta"] = 0.0
    assert support.is_operator_supported(op)


def test_constraint_splitv_inferred():
    # SplitV requires a maximum of one inferred shape (-1)
    qp = testutil.default_quant_params()
    op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
    # Two -1 entries: more than one inferred size, rejected.
    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
    op.add_input_tensor(sizes)
    assert not support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
    # A single -1 entry is allowed.
    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
    op.add_input_tensor(sizes)
    assert support.is_operator_supported(op)
455
456
def test_constraint_concat_pass():
    # A working concat
    op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 3
    assert support.is_operator_supported(op)


def test_constraint_axis_exists():
    # Missing axis attribute
    op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    assert not support.is_operator_supported(op)


def test_constraint_axis_valid():
    # Invalid axis attribute
    op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    # Axis 7 is outside the 4D tensor's dimensions.
    op.attrs["axis"] = 7
    assert not support.is_operator_supported(op)


def test_constraint_matching_dimensionality():
    # Mismatching dimensionality: 4D+2D=4D
    op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 3
    assert not support.is_operator_supported(op)


def test_constraint_valid_dimensions():
    # Mismatching dimension value:
    # ifm2 has w and h as 2, which is not the axis to concat and doesnt match ifm1 or ofm
    op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 2, 2, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 3
    assert not support.is_operator_supported(op)
505
506
def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
    """Build a StridedSlice op with constant begin/end tensors and all-1 strides.

    All mask attributes default to 0; tests adjust them afterwards.
    """
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, DataType.uint8, "in")
    in0.quantization = qp
    in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
    in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
    # Strides default to 1 in every dimension (the only supported value).
    in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
    out = Tensor(out_shape, DataType.uint8, "out")
    out.quantization = qp
    attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
    return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
518
519
def create_pad_op(
    in_shape,
    out_shape,
    padding,
    in_dtype=DataType.int8,
    out_dtype=DataType.int8,
    pad_dtype=DataType.int32,
    pad_setting=Padding.VALID,
):
    """Build a Pad op whose output feeds a Conv2DBias consumer.

    Returns the Pad op; the conv consumer (with padding attribute
    ``pad_setting``) is what the pad-fusing constraint checks inspect.
    """
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, in_dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
    out = Tensor(out_shape, out_dtype, "out")
    out.quantization = qp.clone()
    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    conv_out_tens = Tensor(in_shape, in_dtype, "output")
    conv_out_tens.quantization = qp.clone()
    weight_tens = Tensor(in_shape, in_dtype, "weights")
    weight_tens.values = np.zeros(weight_tens.shape)
    weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
    weight_tens.quantization = qp.clone()
    bias_tens = Tensor(out_shape, pad_dtype, "biases")
    attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
    conv2d_op = testutil.create_op(Op.Conv2DBias, [out, weight_tens, bias_tens], conv_out_tens, attrs)
    # NOTE(review): `out` is already the conv's first input above; adding it
    # again looks redundant (possibly a leftover merge artifact) — confirm
    # whether the duplicate input is intentional.
    conv2d_op.add_input_tensor(out)
    return op
548
549
def test_constraint_pad_input_count():
    # Incorrect number of input tensors (2)
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
    assert support.is_operator_supported(op)
    # A third input invalidates the op.
    op.add_input_tensor(op.inputs[0].clone())
    assert not support.is_operator_supported(op)


def test_constraint_padded_dimensions():
    # Incorrect padding dimensions, can only pad width and height
    # Here the batch dimension is padded too ([1, 1] in position 0).
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [1, 1], [0, 0]],)
    assert not support.is_operator_supported(op)


def test_constraint_pad_shape():
    # PAD operator must be of shape (4,2)
    # Five row pairs give shape (5, 2), which is rejected.
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]],)
    assert not support.is_operator_supported(op)


def test_constraint_pad_none():
    # An empty padding tensor is rejected.
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[],)
    assert not support.is_operator_supported(op)
573
574
def test_constraint_pad_dtype():
    # PAD operator dtype should be int32 or int64
    # Use a valid (4, 2) padding so the *only* violated constraint is the
    # int16 pad dtype; the original 5-pair padding was already rejected for
    # its shape, so the dtype check was never actually isolated.
    op = create_pad_op(
        in_shape=[1, 1, 1, 1],
        out_shape=[1, 3, 3, 1],
        padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
        pad_dtype=DataType.int16,
    )
    assert not support.is_operator_supported(op)
584
585
def test_constraint_pad_consumer():
    # PAD operator must be followed by a valid consumer with Padding.VALID attribute
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
    assert support.is_operator_supported(op)
    # A consumer conv with SAME padding is invalid.
    op = create_pad_op(
        in_shape=[1, 1, 1, 1],
        out_shape=[1, 3, 3, 1],
        padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
        pad_setting=Padding.SAME,
    )
    assert not support.is_operator_supported(op)
    # A concat consumer is not a fusable consumer type.
    op_consumer = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    op.ofm.consumer_list = [op_consumer]
    assert not support.is_operator_supported(op)
    op_consumer = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op_consumer.attrs = {
        "stride_w": 2,
        "stride_h": 2,
        "filter_width": 2,
        "filter_height": 2,
        "padding": Padding.VALID,
    }
    op.ofm.consumer_list = [op_consumer]
    assert not support.is_operator_supported(op)
610
611
Michael McGeagh65fd9982020-10-20 11:49:28 +0100612def create_strided_slice():
613 # Creates a valid strided slice operator with some valid inputs/outputs
614 op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
615 op.attrs["begin_mask"] = 1
616 op.attrs["end_mask"] = 9
617 assert support.is_operator_supported(op)
618 return op
619
620
def test_constraint_stridedslice_input_count():
    # Adding a fifth input tensor makes the input count invalid
    ss_op = create_strided_slice()
    extra_input = ss_op.inputs[0].clone()
    ss_op.add_input_tensor(extra_input)
    assert not support.is_operator_supported(ss_op)
626
627
def test_constraint_stridedslice_inputs_const():
    # The begin (1), end (2) and stride (3) inputs must all have constant values
    for input_idx in (1, 2, 3):
        ss_op = create_strided_slice()
        ss_op.inputs[input_idx].values = None
        assert not support.is_operator_supported(ss_op)
639
640
Michael McGeagh65fd9982020-10-20 11:49:28 +0100641def test_constraint_stridedslice_stride_values():
642 # Unsupported strides
643 op = create_strided_slice()
644 op.inputs[3].values = [1, 1, 2, 1]
645 assert not support.is_operator_supported(op)
646
647
def test_constraint_ellipsis_mask():
    # A non-zero ellipsis mask is unsupported
    ss_op = create_strided_slice()
    ss_op.attrs["ellipsis_mask"] = 1
    assert not support.is_operator_supported(ss_op)
653
654
def test_constraint_axis_masks():
    # Using only new_axis_mask is supported
    ss_op = create_strided_slice()
    ss_op.attrs["new_axis_mask"] = 2
    assert support.is_operator_supported(ss_op)
    # Using only shrink_axis_mask is supported
    ss_op = create_strided_slice()
    ss_op.attrs["shrink_axis_mask"] = 3
    assert support.is_operator_supported(ss_op)
    # Setting both masks at once is not supported
    ss_op.attrs["new_axis_mask"] = 2
    assert not support.is_operator_supported(ss_op)
666
667
def test_constraint_slice_ranges():
    # Slices where the end offset <= begin offset are rejected
    ss_op = create_strided_slice()
    ss_op.inputs[1].values = [0, 7, 2, 0]  # modified begin values
    assert not support.is_operator_supported(ss_op)
    ss_op = create_strided_slice()
    ss_op.inputs[2].values = [0, 7, 2, 0]  # modified end values
    assert not support.is_operator_supported(ss_op)
    # Clearing either mask also produces an invalid range
    ss_op = create_strided_slice()
    ss_op.attrs["begin_mask"] = 0
    assert not support.is_operator_supported(ss_op)
    ss_op = create_strided_slice()
    ss_op.attrs["end_mask"] = 0
    assert not support.is_operator_supported(ss_op)
682
683
def test_constraint_matching_inputs_types():
    # Mixing ifm/ifm2 data types is rejected (both default to uint8)
    mul_op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    mul_op.ifm2.dtype = DataType.int8
    assert not support.is_operator_supported(mul_op)
689
690
def test_constraint_matching_signed():
    # With signed inputs, an unsigned output is rejected
    mul_op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
    mul_op.ofm.dtype = DataType.uint8
    assert not support.is_operator_supported(mul_op)
696
697
def test_constraint_unsigned_valid():
    # With unsigned inputs the output must either match the input type or be int32
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    # matching output type (default uint8) is accepted
    assert support.is_operator_supported(op)
    # signed 8/16-bit outputs are rejected
    for bad_dtype in (DataType.int8, DataType.int16):
        op.ofm.dtype = bad_dtype
        assert not support.is_operator_supported(op)
    # int32 output is accepted
    op.ofm.dtype = DataType.int32
    assert support.is_operator_supported(op)
710
711
def test_constraint_inputs_int32():
    # SHL requires both inputs to be int32
    shl_op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    assert not support.is_operator_supported(shl_op)  # default uint8 inputs: rejected
    shl_op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert support.is_operator_supported(shl_op)
    shl_op.ifm2.dtype = DataType.int16
    assert not support.is_operator_supported(shl_op)  # one input not int32: rejected
720
721
def test_constraint_output_int32():
    # SHL output must be int32
    shl_op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert support.is_operator_supported(shl_op)
    shl_op.ofm.dtype = DataType.int16
    assert not support.is_operator_supported(shl_op)
728
729
def test_constraint_matching_quantization_parameters():
    # ifm/ifm2 quantization parameters must match the ofm's
    mismatching_qp = QuantizationParameters()
    mismatching_qp.scale_f32 = np.float32(1.5)
    mismatching_qp.zero_point = 128
    # valid - all matching (uses default quant params)
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # invalid - ifm mismatches ofm
    op.ifm.quantization = mismatching_qp
    assert not support.is_operator_supported(op)
    # invalid - ifm2 mismatches ofm
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    op.ifm2.quantization = mismatching_qp
    assert not support.is_operator_supported(op)
    # invalid - both ifm and ifm2 mismatch ofm
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    op.ifm.quantization = mismatching_qp
    op.ifm2.quantization = mismatching_qp
    assert not support.is_operator_supported(op)
    # valid - everything matches again once the ofm agrees
    op.ofm.quantization = mismatching_qp
    assert support.is_operator_supported(op)
    # valid - no ifm2 present
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], None, [1, 8, 8, 8])
    assert support.is_operator_supported(op)
Michael McGeagh65fd9982020-10-20 11:49:28 +0100754
755
def test_constraint_elemwise_batch_size():
    # Elementwise ops only allow batch > 1 when the tensors have at most 2 dims
    # BINARY CASE
    op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [2, 2], [2, 2])
    assert support.is_operator_supported(op)  # 2D with batch 2: ok
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 2], [1, 2, 2], [1, 2, 2])
    assert support.is_operator_supported(op)  # 3D with batch 1: ok
    op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2], [2, 2, 2], [2, 2, 2])
    assert not support.is_operator_supported(op)  # 3D with batch 2: rejected

    # UNARY CASE
    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
    assert support.is_operator_supported(op)  # 2D with batch 2: ok
    op = testutil.create_elemwise_op(Op.CLZ, "op", [1, 2, 2], None, [1, 2, 2], datatype=DataType.int32)
    assert support.is_operator_supported(op)  # 3D with batch 1: ok
    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2], None, [2, 2, 2], datatype=DataType.int32)
    assert not support.is_operator_supported(op)  # 3D with batch 2: rejected
778
779
def test_constraint_matching_either_shapes():
    # BINARY CASE
    # At least one of the two input shapes must match the ofm shape
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4])
    assert support.is_operator_supported(op)  # ifm2 matches ofm
    op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4])
    assert support.is_operator_supported(op)  # ifm matches ofm
    op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2])
    assert not support.is_operator_supported(op)  # neither input matches
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16])
    assert not support.is_operator_supported(op)

    # UNARY CASE
    # With no second input, the single ifm shape must match the ofm shape
    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32)
    assert not support.is_operator_supported(op)
800
801
Andreas Nevalainend059d8b2020-11-19 14:40:35 +0100802def test_constraint_broadcast_shapes():
803 # BINARY CASE
804 # Only allow broadcast to 1 dim, for 1 rank index
805 op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4], [1, 2, 4], [1, 2, 4])
806 assert support.is_operator_supported(op)
807 op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 1, 4], [1, 2, 4])
808 assert support.is_operator_supported(op)
809 # Only allow broadcast to 1 dim, for 3 rank indexes
810 op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 1, 1], [1, 4, 8, 16], [1, 4, 8, 16])
811 assert support.is_operator_supported(op)
812 op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 8, 16], [1, 1, 1, 1], [1, 4, 8, 16])
813 assert support.is_operator_supported(op)
814 # One broadcast dim not 1
815 op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 4, 4], [1, 4, 4])
816 assert not support.is_operator_supported(op)
817 op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 4], [1, 2, 4], [1, 4, 4])
818 assert not support.is_operator_supported(op)
819 # OFM shape dim largest ifm/ifm2 shape dim
820 op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4])
821 assert not support.is_operator_supported(op)
822 op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4])
823 assert not support.is_operator_supported(op)
824 op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 1, 16])
825 assert not support.is_operator_supported(op)
826 op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 1, 16])
827 assert not support.is_operator_supported(op)
828
829
Michael McGeagh65fd9982020-10-20 11:49:28 +0100830def test_constraint_alpha_valid():
831 # Alpha cannot be negative
832 op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])
833 op.attrs["alpha"] = 0
834 assert support.is_operator_supported(op)
835 op.attrs["alpha"] = -1
836 assert not support.is_operator_supported(op)