blob: e65717a8f3f75fecf7cfdd963dfbf1dc34caf479 [file] [log] [blame]
Tim Hall3b1578e2023-01-13 17:57:25 +00001# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
# Description:
# Unit tests for tflite_supported_operators
Tim Hall9cf63a32023-06-27 12:07:49 +010019from typing import List
20
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020021import numpy as np
Raul Farkas090f18a2023-01-24 16:29:06 +000022import pytest
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020023
24from ethosu.vela.data_type import DataType
25from ethosu.vela.operation import ActivationFunction
26from ethosu.vela.operation import Op
27from ethosu.vela.operation import Padding
28from ethosu.vela.tensor import create_const_tensor
29from ethosu.vela.tensor import QuantizationParameters
30from ethosu.vela.tensor import Tensor
31from ethosu.vela.test import testutil
32from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators
33
# Single supported-operators checker; the tests in this module call support.is_operator_supported(op) on it.
support = TFLiteSupportedOperators()
35
36
def test_constraint_tens_dtype():
    """Reject float32 tensors: only uint8, int8, int16 and int32 are accepted."""
    relu_op = testutil.create_op_with_quant_tensors(
        Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.float32
    )
    assert not support.is_operator_supported(relu_op)
41
42
def test_constraint_tens_int32_ops():
    """int32 tensors are only accepted for a selected set of operator types."""
    # Mul accepts int32 tensors
    mul_op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], datatype=DataType.int32)
    assert support.is_operator_supported(mul_op)
    # Relu does not, so the same datatype makes it unsupported
    relu_op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert not support.is_operator_supported(relu_op)
49
50
def test_constraint_tens_dimension():
    """Every tensor dimension must lie in the inclusive range 1-65535.

    Besides the combined case, each bound is also violated on its own, so a
    regression in only the lower or only the upper limit check is detected.
    """
    # Both bounds violated: 0 on the IFM, 65536 on the OFM
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 0], [1, 8, 8, 65536])
    assert not support.is_operator_supported(op)
    # Lower bound only: a zero-sized dimension
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 0], [1, 8, 8, 8])
    assert not support.is_operator_supported(op)
    # Upper bound only: one past the maximum, with matching IFM/OFM shapes
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 65536], [1, 8, 8, 65536])
    assert not support.is_operator_supported(op)
55
56
def test_constraint_tens_quant_per_axis_not_supp():
    """Elementwise ops must reject array-valued (per-axis) quantization scales."""
    per_axis_qp = QuantizationParameters()
    per_axis_qp.zero_point = np.zeros((1, 3))
    per_axis_qp.scale_f32 = np.ones((1, 3))
    mul_op = testutil.create_elemwise_op(
        Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=per_axis_qp
    )
    assert not support.is_operator_supported(mul_op)
64
65
def test_constraint_tens_quant_per_axis_is_supp():
    """Conv2DBias stays supported when its bias uses array-valued (per-axis) quantization."""
    conv_op = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 1, 1, 3], [1, 1, 1, 3], weights_shape=[1, 1, 1, 3], bias_shape=[3]
    )
    conv_op.attrs = {"stride_w": 1, "stride_h": 1}
    # Baseline with the bias quantization created by testutil
    assert support.is_operator_supported(conv_op)

    per_axis_qp = QuantizationParameters()
    per_axis_qp.zero_point = np.zeros((1, 3))
    per_axis_qp.scale_f32 = np.ones((1, 3))
    conv_op.bias.quantization = per_axis_qp
    # Still supported after switching the bias to per-axis parameters
    assert support.is_operator_supported(conv_op)
77
78
def test_constraint_fc_output_2d_is_supp():
    """FullyConnected with a 2D OFM is supported for both a 4D and a 2D IFM."""
    # (ifm_shape, ofm_shape, weights_shape)
    cases = [
        ([4, 8, 8, 4], [32, 32], [4, 8, 8, 4]),
        ([1, 1024], [16, 64], [1, 1024]),
    ]
    for ifm_shape, ofm_shape, weights_shape in cases:
        fc_op = testutil.create_op_with_quant_tensors(
            Op.FullyConnected, ifm_shape, ofm_shape, weights_shape=weights_shape
        )
        assert support.is_operator_supported(fc_op)
84
85
def test_constraint_faf():
    """A fused activation function, if set, must be a valid activation op type."""
    relu_op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
    # Conv2D is not an activation, so fusing it must make the op unsupported
    bogus_activation = ActivationFunction(Op.Conv2D)
    relu_op.activation = bogus_activation
    assert not support.is_operator_supported(relu_op)
91
92
def test_constraint_faf_ofm_dtype():
    """With a fused activation present, only 8- or 16-bit OFM datatypes are supported."""
    shape = [1, 8, 8, 8]
    candidate_dtypes = (DataType.int8, DataType.uint8, DataType.int16, DataType.int32)
    for dtype in candidate_dtypes:
        add_op = testutil.create_elemwise_op(Op.Add, "op", shape, shape, shape, datatype=dtype)
        add_op.activation = ActivationFunction(Op.Relu)
        at_most_16_bits = dtype.size_in_bytes() <= 2
        assert support.is_operator_supported(add_op) == at_most_16_bits, f"Data type: {dtype}"
101
102
def test_constraint_conv_pass():
    """Baseline: the simplest 1x1 Conv2DBias with unit strides is supported."""
    conv_op = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1]
    )
    conv_op.attrs = dict(stride_w=1, stride_h=1)
    assert support.is_operator_supported(conv_op)
108
109
@pytest.mark.parametrize(
    "ifm_shape, stride_w, stride_h, supported",
    [
        [[1, 8, 8, 8], 0, 20, False],
        [[1, 8, 8, 8], 20, 0, False],
        [[1, 8, 8, 8], 4, 3, True],
        [[1, 8, 8, 8], 4, 5, False],
        [[1, 8, 8, 8], 4, 9, False],
        [[1, 8, 8, 8], 3, 3, True],
        [[1, 8, 8, 8], 1, 1, True],
        [[1, 8, 8, 8], 20, 2, False],
        [[1, 8, 40, 8], 20, 2, True],
        [[1, 8, 40, 8], 6, 3, True],
        [[1, 8, 40, 8], 8, 1, True],
    ],
)
def test_constraint_stride_range(ifm_shape: List[int], stride_w: int, stride_h: int, supported: bool):
    """Stride width and height must lie within a certain range.

    A positive stride pair that is rejected for an 8x8 OFM is expected to be
    accepted once the OFM width and height are both 1.
    """
    strides = {"stride_w": stride_w, "stride_h": stride_h}

    def make_conv(ofm_shape):
        # Fresh op with its own copy of the stride attributes
        conv_op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, ifm_shape, ofm_shape, [1, 1, 1, 1])
        conv_op.attrs = dict(strides)
        return conv_op

    assert support.is_operator_supported(make_conv([1, 8, 8, 8])) == supported
    if supported or stride_w <= 0 or stride_h <= 0:
        return
    # Not supported, but with ofm width and height = 1 -> supported
    assert support.is_operator_supported(make_conv([1, 1, 1, 8]))
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200136
137
def test_constraint_dilated_height_range():
    """A kernel of height 65 (weights_shape[0]) is outside the allowed dilated height range."""
    tall_conv = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[65, 64, 1, 1]
    )
    tall_conv.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(tall_conv)
143
144
def test_constraint_dilated_product_range():
    """A 64x65 kernel has a width x height product outside the allowed dilated range."""
    wide_conv = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[64, 65, 1, 1]
    )
    wide_conv.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(wide_conv)
150
151
def test_constraint_weights_type():
    """Weights must be 8-bit; building the op with datatype int16 must be rejected.

    The int16 datatype is passed to testutil and presumably applies to the
    weight tensor as well.
    """
    conv_op = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias,
        [1, 8, 8, 8],
        [1, 8, 8, 8],
        weights_shape=[1, 1, 1, 1],
        datatype=DataType.int16,
    )
    conv_op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(conv_op)
159
160
def test_constraint_weights_const():
    """Weights supplied as a non-constant tensor are rejected."""
    conv_op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
    conv_op.attrs = {"stride_w": 1, "stride_h": 1}
    # A plain Tensor with no values assigned serves as a non-const weight input
    runtime_weights = Tensor([64, 64, 1, 1], DataType.uint8, "weights")
    runtime_weights.quantization = testutil.default_quant_params()
    conv_op.add_input_tensor(runtime_weights)
    assert not support.is_operator_supported(conv_op)
169
170
def test_constraint_weights_limit():
    """The sum of weights is limited; a weight zero point of 127 * 65536 + 1 is expected to exceed it."""
    conv_op = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1]
    )
    conv_op.attrs = {"stride_w": 1, "stride_h": 1}
    over_limit = (127 * 65536) + 1
    conv_op.weights.quantization.zero_point = np.array([[[[over_limit]]]])
    assert not support.is_operator_supported(conv_op)
177
178
def test_constraint_bias_type():
    """The bias must have a certain datatype; a uint8 bias tensor is rejected."""
    conv_op = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1]
    )
    conv_op.attrs = {"stride_w": 1, "stride_h": 1}
    uint8_bias = Tensor([1, 8, 8, 8], DataType.uint8, "bias")
    conv_op.add_input_tensor(uint8_bias)
    assert not support.is_operator_supported(conv_op)
186
187
def test_constraint_bias_40bit():
    """Bias values must fit in 40 bits; 0x01FF_FFFF_FFFF needs 41 bits and is rejected."""
    conv_op = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1]
    )
    conv_op.attrs = {"stride_w": 1, "stride_h": 1}
    wide_bias = Tensor([1, 1, 1, 1], DataType.int64, "bias")
    wide_bias.values = np.array([0x01FF_FFFF_FFFF])
    conv_op.add_input_tensor(wide_bias)
    assert not support.is_operator_supported(conv_op)
196
197
def test_constraint_batch_size():
    """A Conv2DBias whose IFM has batch 2 is rejected."""
    batched_conv = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [2, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1]
    )
    batched_conv.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(batched_conv)
202
203
def test_constraint_depth_multiplier():
    """Check how depth_multiplier interacts with the OFM channel count of DepthwiseConv2DBias."""
    # (ofm_shape, depth_multiplier, expected_supported)
    cases = [
        # Valid. Depth multiplier is 1 so no further constraints
        ([1, 1, 1, 2], 1, True),
        # Invalid. Depth multiplier doesnt equal ofm channel
        ([1, 1, 1, 1], 2, False),
        # Valid. Depth multiplier is equal to ofm channel
        ([1, 1, 1, 2], 2, True),
    ]
    for ofm_shape, depth_multiplier, expected in cases:
        dw_op = testutil.create_op_with_quant_tensors(
            Op.DepthwiseConv2DBias, [1, 1, 1, 1], ofm_shape, weights_shape=[1, 1, 1, 1]
        )
        dw_op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": depth_multiplier}
        assert support.is_operator_supported(dw_op) == expected
223
224
def test_constraint_tconv_stride():
    """Transpose-conv stride (WxH) must be 1x1 or 2x2, or 2x1 with IFM height and kernel height both 1.

    Each invalid 2x1 case breaks exactly one of the two height conditions, so
    both conditions are exercised independently (weights_shape[0] is the kernel
    height).
    """

    def make_tconv(ofm_shape, weights_shape, ifm_shape, stride_w, stride_h):
        # The op is created with a [0]-shaped first input; the real IFM is appended afterwards
        op = testutil.create_op_with_quant_tensors(
            Op.Conv2DBackpropInput, [0], ofm_shape, weights_shape=weights_shape
        )
        op.attrs = {"stride_w": stride_w, "stride_h": stride_h, "padding": Padding.SAME}
        ifm = Tensor(ifm_shape, DataType.uint8, "ifm")
        ifm.quantization = testutil.default_quant_params()
        op.add_input_tensor(ifm)
        return op

    # Valid 2x2
    assert support.is_operator_supported(make_tconv([1, 2, 2, 1], [1, 1, 1, 1], [1, 1, 1, 1], 2, 2))
    # Valid 1x1
    assert support.is_operator_supported(make_tconv([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], 1, 1))
    # Valid 2x1 (WxH) ifm h and kernel h = 1
    assert support.is_operator_supported(make_tconv([1, 1, 2, 1], [1, 1, 1, 1], [1, 1, 1, 1], 2, 1))
    # Invalid 2x1 (WxH) ifm h = 2 and kernel h = 1
    assert not support.is_operator_supported(make_tconv([1, 1, 2, 1], [1, 1, 1, 1], [1, 2, 1, 1], 2, 1))
    # Invalid 2x1 (WxH) ifm h = 1 and kernel h = 2
    # Same OFM and IFM as the valid 2x1 case, so only the kernel height differs
    assert not support.is_operator_supported(make_tconv([1, 1, 2, 1], [2, 1, 1, 1], [1, 1, 1, 1], 2, 1))
    # Invalid 1x2 (WxH)
    assert not support.is_operator_supported(make_tconv([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], 1, 2))
268
269
def test_constraint_tconv_same():
    """With SAME padding and stride 2, a 1x1 IFM is accepted for a 2x2 OFM but not for a 4x4 OFM."""
    for ofm_shape, expected in (([1, 2, 2, 1], True), ([1, 4, 4, 1], False)):
        tconv_op = testutil.create_op_with_quant_tensors(
            Op.Conv2DBackpropInput, [0], ofm_shape, weights_shape=[1, 1, 1, 1]
        )
        tconv_op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.SAME}
        ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
        ifm.quantization = testutil.default_quant_params()
        tconv_op.add_input_tensor(ifm)
        assert support.is_operator_supported(tconv_op) == expected
285
286
def test_constraint_tconv_valid():
    """With VALID padding, stride 2, a 1x1 IFM and a 4x4 OFM, a 4x4 kernel is accepted and a 2x2 kernel is not."""
    for weights_shape, expected in (([4, 4, 1, 1], True), ([2, 2, 1, 1], False)):
        tconv_op = testutil.create_op_with_quant_tensors(
            Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=weights_shape
        )
        tconv_op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.VALID}
        ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
        ifm.quantization = testutil.default_quant_params()
        tconv_op.add_input_tensor(ifm)
        assert support.is_operator_supported(tconv_op) == expected
302
303
def test_constraint_filter_range():
    """AvgPool filter limits depend on padding: a 20x20 filter fails with SAME and passes with VALID."""
    pool_op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    pool_op.attrs = {
        "stride_w": 2,
        "stride_h": 2,
        "filter_width": 20,
        "filter_height": 20,
        "padding": Padding.SAME,
    }
    # SAME padding restricts both W and H to max 8
    assert not support.is_operator_supported(pool_op)
    # VALID padding limits are much larger
    pool_op.attrs["padding"] = Padding.VALID
    assert support.is_operator_supported(pool_op)
313
314
def test_constraint_filter_height_range_valid_pad():
    """With VALID padding, AvgPool accepts a filter height of 256 but not 257."""
    pool_op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    pool_op.attrs = {
        "stride_w": 2,
        "stride_h": 2,
        "filter_width": 2,
        "filter_height": 256,
        "padding": Padding.VALID,
    }
    assert support.is_operator_supported(pool_op)
    # One past the VALID-padding height limit
    pool_op.attrs["filter_height"] = 257
    assert not support.is_operator_supported(pool_op)
323
324
def test_constraint_filter_product_height_range_valid_pad():
    """With VALID padding, AvgPool accepts a 256x256 filter but not 257x256 (W x H)."""
    pool_op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    pool_op.attrs = {
        "stride_w": 2,
        "stride_h": 2,
        "filter_width": 256,
        "filter_height": 256,
        "padding": Padding.VALID,
    }
    assert support.is_operator_supported(pool_op)
    # Widening by one exceeds the W x H limit
    pool_op.attrs["filter_width"] = 257
    assert not support.is_operator_supported(pool_op)
333
334
def test_constraint_filter_height_range():
    """MaxPool filter height is limited to 256 regardless of padding."""
    pool_op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
    pool_op.attrs = {
        "stride_w": 2,
        "stride_h": 2,
        "filter_width": 2,
        "filter_height": 256,
        "padding": Padding.SAME,
    }
    assert support.is_operator_supported(pool_op)
    pool_op.attrs["filter_height"] = 257
    # Rejected with either padding mode
    for padding in (Padding.SAME, Padding.VALID):
        pool_op.attrs["padding"] = padding
        assert not support.is_operator_supported(pool_op)
346
347
def test_constraint_filter_product_height_range():
    """MaxPool filter W x H is limited to 256x256 regardless of padding."""
    pool_op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
    pool_op.attrs = {
        "stride_w": 2,
        "stride_h": 2,
        "filter_width": 256,
        "filter_height": 256,
        "padding": Padding.SAME,
    }
    assert support.is_operator_supported(pool_op)
    pool_op.attrs["filter_width"] = 257
    # Rejected with either padding mode
    for padding in (Padding.SAME, Padding.VALID):
        pool_op.attrs["padding"] = padding
        assert not support.is_operator_supported(pool_op)
359
360
def test_constraint_resize():
    """Every resize op type accepts only specific IFM -> OFM scalings."""
    # (ifm_shape, ofm_shape, size, align_corners, expected_supported);
    # align_corners=None leaves the attribute untouched
    cases = [
        # IFM W and H == 1
        ([1, 1, 1, 8], [1, 8, 8, 8], [8, 8], None, True),
        # IFM == OFM
        ([1, 8, 8, 8], [1, 8, 8, 8], [8, 8], None, True),
        # IFM x2 == OFM ; align_corners = False
        ([1, 4, 4, 8], [1, 8, 8, 8], [8, 8], None, True),
        # IFM x4 == OFM ; align_corners = False
        ([1, 4, 4, 8], [1, 16, 16, 8], [16, 16], None, True),
        # IFM x8 == OFM ; align_corners = False
        ([1, 4, 4, 8], [1, 32, 32, 8], [32, 32], None, True),
        # IFM -1 x2 == OFM -1 ; align_corners = True
        ([1, 4, 4, 8], [1, 7, 7, 8], [7, 7], True, True),
        # IFM -1 x4 == OFM -1 ; align_corners = True
        ([1, 4, 4, 8], [1, 13, 13, 8], [13, 13], True, True),
        # IFM -1 x8 == OFM -1 ; align_corners = True
        ([1, 4, 4, 8], [1, 25, 25, 8], [25, 25], True, True),
        # Invalid case - upscale size
        ([1, 4, 4, 8], [1, 17, 17, 8], [17, 17], None, False),
        # Invalid case - upscale size with align corners
        ([1, 4, 4, 8], [1, 15, 15, 8], [15, 15], True, False),
    ]
    for resize_op in Op.op_set(Op.is_resize_op):
        for ifm_shape, ofm_shape, size, align_corners, expected in cases:
            # Fresh list copies so no shape/value list is shared between ops
            op = testutil.create_op_with_quant_tensors(resize_op, list(ifm_shape), list(ofm_shape))
            op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, list(size)))
            if align_corners is not None:
                op.attrs["align_corners"] = align_corners
            assert support.is_operator_supported(op) == expected
Tim Hall47c76362022-07-18 21:26:47 +0100416
417
def test_constraint_resize_size():
    """The constant size input of a resize op must agree with the OFM size."""
    for resize_op in Op.op_set(Op.is_resize_op):
        # A size of [7, 7] disagrees with the 8x8 OFM
        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
        mismatched_size = create_const_tensor("size", [2], DataType.int32, [7, 7])
        op.add_input_tensor(mismatched_size)
        assert not support.is_operator_supported(op)
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200424
425
def test_constraint_resize_attrs():
    """align_corners and half_pixel_centers must not both be set on a resize op."""
    for resize_op in Op.op_set(Op.is_resize_op):
        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8]))
        op.attrs.update({"align_corners": True, "half_pixel_centers": True})
        assert not support.is_operator_supported(op)
Tim Hall47c76362022-07-18 21:26:47 +0100434
435
def test_constraint_concat_pass():
    """Baseline: concatenating two 1x1x1x4 tensors along axis 3 is supported."""
    concat_op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    second_input = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    second_input.quantization = testutil.default_quant_params()
    concat_op.add_input_tensor(second_input)
    concat_op.attrs["axis"] = 3
    assert support.is_operator_supported(concat_op)
444
445
def create_pad_op(
    in_shape,
    out_shape,
    padding,
    in_dtype=DataType.int8,
    out_dtype=DataType.int8,
    pad_dtype=DataType.int32,
):
    """Build a Pad op from a quantized input, a constant padding tensor and a quantized output.

    The padding tensor's shape is taken from the nested ``padding`` list, except
    that an empty list yields a padding tensor of shape ``[]``. The output gets
    a clone of the input's default quantization parameters.
    """
    quant = testutil.default_quant_params()
    ifm = Tensor(in_shape, in_dtype, "in")
    ifm.quantization = quant
    if padding == []:
        pad_shape = []
    else:
        pad_shape = list(np.shape(padding))
    paddings = create_const_tensor(name="pad", shape=pad_shape, values=padding, dtype=pad_dtype)
    ofm = Tensor(out_shape, out_dtype, "out")
    ofm.quantization = quant.clone()
    return testutil.create_op(Op.Pad, [ifm, paddings], ofm)
463
464
def test_constraint_padded_dimensions():
    """Incorrect padding dimensions are rejected: Pad can only pad width and height."""
    # (padding, expected_supported)
    cases = [
        # Four padding rows, the first three non-zero
        ([[1, 1], [1, 1], [1, 1], [0, 0]], False),
        # Three padding rows, only the first two non-zero
        ([[1, 1], [1, 1], [0, 0]], True),
        # Three padding rows, the last one non-zero
        ([[1, 1], [1, 1], [0, 1]], False),
    ]
    for padding, expected in cases:
        pad_op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=padding)
        assert support.is_operator_supported(pad_op) == expected
485
486
def test_constraint_pad_shape():
    """The padding tensor of a Pad op must have shape (3,2) or (4,2)."""
    # A (3,2) padding is accepted
    three_rows = [[1, 1], [1, 1], [0, 0]]
    pad_op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=three_rows)
    assert support.is_operator_supported(pad_op)
    # A (5,2) padding is rejected
    five_rows = [[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]]
    pad_op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=five_rows)
    assert not support.is_operator_supported(pad_op)
497
498
def test_constraint_pad_none():
    """A Pad op built from an empty padding list is rejected."""
    empty_pad_op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[])
    assert not support.is_operator_supported(empty_pad_op)
506
507
def test_constraint_pad_dtype():
    """The padding tensor of a Pad op must be int32 or int64.

    A valid (3,2) padding is used so that the datatype is the only thing that
    differs from a supported op. A (5,2) padding is already rejected for its
    shape, so the test would pass without the datatype check ever being exercised.
    """
    # Baseline: the valid padding with the default int32 datatype is supported
    op = create_pad_op(
        in_shape=[1, 1, 1, 1],
        out_shape=[1, 3, 3, 1],
        padding=[[1, 1], [1, 1], [0, 0]],
    )
    assert support.is_operator_supported(op)
    # int16 is not an accepted padding datatype
    op = create_pad_op(
        in_shape=[1, 1, 1, 1],
        out_shape=[1, 3, 3, 1],
        padding=[[1, 1], [1, 1], [0, 0]],
        pad_dtype=DataType.int16,
    )
    assert not support.is_operator_supported(op)
517
518
def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
    """Build a uint8 StridedSlice op with constant begin/end tensors, unit strides and all masks 0.

    Every tensor shares the same default quantization object. The stride tensor
    length follows ``end_offsets``.
    """
    quant = testutil.default_quant_params()
    ifm = Tensor(in_shape, DataType.uint8, "in")
    ifm.quantization = quant
    begin = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=quant)
    end = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=quant)
    rank = len(end_offsets)
    strides = create_const_tensor("strides", [rank], DataType.uint8, [1] * rank, quantization=quant)
    ofm = Tensor(out_shape, DataType.uint8, "out")
    ofm.quantization = quant
    mask_names = ("ellipsis_mask", "new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask")
    mask_attrs = dict.fromkeys(mask_names, 0)
    return testutil.create_op(Op.StridedSlice, [ifm, begin, end, strides], ofm, attrs=mask_attrs)
530
531
def create_strided_slice():
    """Return a StridedSlice op that passes the support checks.

    It serves as a valid baseline that individual tests then modify.
    """
    op = create_strided_slice_op(
        in_shape=[1, 10, 10, 10],
        out_shape=[1, 5, 5, 10],
        start_offsets=[127, 2, 2, 0],
        end_offsets=[0, 7, -3, 0],
    )
    op.attrs.update(begin_mask=1, end_mask=9, offset=False)
    assert support.is_operator_supported(op)
    return op
540
541
def test_constraint_stridedslice_stride_values():
    """A StridedSlice with a non-unit stride is rejected."""
    slice_op = create_strided_slice()
    # Stride 2 at index 2 of the strides input
    slice_op.inputs[3].values = [1, 1, 2, 1]
    assert not support.is_operator_supported(slice_op)
547
548
def test_constraint_stridedslice_offset_false():
    """A StridedSlice is only supported while its offset attribute is False."""
    slice_op = create_strided_slice()
    slice_op.attrs["offset"] = True
    assert not support.is_operator_supported(slice_op)
554
555
def test_constraint_inputs_int32():
    """SHL requires both of its inputs to be int32."""
    # With the helper's default datatype the op is rejected
    shl_op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    assert not support.is_operator_supported(shl_op)
    # All tensors int32: accepted
    shl_op = testutil.create_elemwise_op(
        Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32
    )
    assert support.is_operator_supported(shl_op)
    # Changing only IFM2 to int16 makes it unsupported again
    shl_op.ifm2.dtype = DataType.int16
    assert not support.is_operator_supported(shl_op)
564
565
def test_constraint_output_int32():
    """SHL requires its output to be int32."""
    shl_op = testutil.create_elemwise_op(
        Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32
    )
    assert support.is_operator_supported(shl_op)
    # Changing only the OFM to int16 makes it unsupported
    shl_op.ofm.dtype = DataType.int16
    assert not support.is_operator_supported(shl_op)
572
573
def test_constraint_matching_quantization_parameters():
    """Minimum is only supported when its input and output quantization parameters match."""
    other_qp = QuantizationParameters()
    other_qp.scale_f32 = np.float32(1.5)
    other_qp.zero_point = 128

    def make_minimum():
        return testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])

    # valid - all matching (uses default quant params)
    op = make_minimum()
    assert support.is_operator_supported(op)
    # invalid - ifm mismatch ofm
    op.ifm.quantization = other_qp
    assert not support.is_operator_supported(op)
    # invalid - ifm2 mismatch ofm
    op = make_minimum()
    op.ifm2.quantization = other_qp
    assert not support.is_operator_supported(op)
    # invalid - both ifm and ifm2 mismatch ofm
    op = make_minimum()
    op.ifm.quantization = other_qp
    op.ifm2.quantization = other_qp
    assert not support.is_operator_supported(op)
    # valid - all matching
    op.ofm.quantization = other_qp
    assert support.is_operator_supported(op)
    # valid - Minimum created without an IFM2 shape (None)
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], None, [1, 8, 8, 8])
    assert support.is_operator_supported(op)
598
599
def test_constraint_elemwise_batch_size():
    """Elementwise ops may only have batch > 1 when their tensors are at most 3D.

    The same shapes are checked for a binary op (Add) and a unary op (CLZ, int32).
    """
    # (shape used for every tensor, expected_supported)
    cases = (
        # Batch can be >1 if dims is <=3D
        ([2, 2, 2], True),
        # For dims >3D, batch must be 1
        ([1, 2, 2, 2], True),
        # invalid case
        ([2, 2, 2, 2], False),
    )
    # BINARY CASE
    for shape, expected in cases:
        add_op = testutil.create_elemwise_op(Op.Add, "op", list(shape), list(shape), list(shape))
        assert support.is_operator_supported(add_op) == expected
    # UNARY CASE
    for shape, expected in cases:
        clz_op = testutil.create_elemwise_op(
            Op.CLZ, "op", list(shape), None, list(shape), datatype=DataType.int32
        )
        assert support.is_operator_supported(clz_op) == expected
622
623
def test_constraint_broadcast_shapes():
    """Binary elementwise broadcasting rules for Add.

    A dimension may only be broadcast from size 1, and each OFM dimension must
    equal the larger of the two input dimensions. Every rule is checked with
    the larger input in the IFM slot and again in the IFM2 slot.
    """
    # (ifm_shape, ifm2_shape, ofm_shape, expected_supported)
    cases = [
        # Only allow broadcast to 1 dim, for 1 rank index
        ([1, 1, 4], [1, 2, 4], [1, 2, 4], True),
        ([1, 2, 4], [1, 1, 4], [1, 2, 4], True),
        # Only allow broadcast to 1 dim, for 3 rank indexes
        ([1, 1, 1, 1], [1, 4, 8, 16], [1, 4, 8, 16], True),
        ([1, 4, 8, 16], [1, 1, 1, 1], [1, 4, 8, 16], True),
        # One broadcast dim not 1
        ([1, 2, 4], [1, 4, 4], [1, 4, 4], False),
        ([1, 4, 4], [1, 2, 4], [1, 4, 4], False),
        # OFM shape dim largest ifm/ifm2 shape dim
        ([1, 4], [4, 4], [1, 4], False),
        # Same violation with the operands swapped (previously a duplicate of the case above)
        ([4, 4], [1, 4], [1, 4], False),
        ([1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 1, 16], False),
        ([1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 1, 16], False),
    ]
    for ifm_shape, ifm2_shape, ofm_shape, expected in cases:
        op = testutil.create_elemwise_op(Op.Add, "op", ifm_shape, ifm2_shape, ofm_shape)
        assert support.is_operator_supported(op) == expected, f"{ifm_shape} x {ifm2_shape} -> {ofm_shape}"
650
651
def create_mean(input_shape, output_shape, axis, datatype, attrs):
    """Build a quantized Mean op for the supported-operator tests.

    Args:
        input_shape: shape of the IFM tensor.
        output_shape: shape of the OFM tensor.
        axis: reduction axes, either a list/tuple of ints (stored as a 1D
            constant indices tensor) or a single int (stored as a scalar
            constant indices tensor).
        datatype: DataType used for both IFM and OFM.
        attrs: attribute dict passed to the op (e.g. {"keep_dims": True}).

    Returns:
        The created Mean operation.

    Raises:
        TypeError: if axis is neither an int nor a list/tuple. Without this
            check, an unsupported type would fail later with an
            UnboundLocalError.
    """
    ifm = Tensor(input_shape, datatype, "in")
    ifm.quantization = testutil.default_quant_params()
    ofm = Tensor(output_shape, datatype, "out")
    ofm.quantization = testutil.default_quant_params()
    if isinstance(axis, (list, tuple)):
        indices = create_const_tensor("indices", [len(axis)], DataType.int32, list(axis))
    elif isinstance(axis, int) and not isinstance(axis, bool):
        # bool is a subclass of int, but True/False is never a meaningful axis
        indices = create_const_tensor("indices", [], DataType.int32, axis)
    else:
        raise TypeError(f"axis must be an int or a list/tuple of ints, got {type(axis).__name__}")
    return testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
663
664
def test_mean_hw_product():
    # Each row: (IFM shape, datatype, attrs, expected support). Every Mean
    # reduces over axes [1, 2] into an OFM of shape [1, 1, 1, 16].
    cases = (
        # Max kernel size checks, per datatype: the first row of each pair is
        # expected to be supported, the second (one row taller) is not.
        ([1, 4096, 4096, 16], DataType.int8, {}, True),
        ([1, 4097, 4096, 16], DataType.int8, {}, False),
        ([1, 2048, 4096, 16], DataType.uint8, {}, True),
        ([1, 2049, 4096, 16], DataType.uint8, {}, False),
        ([1, 16, 4096, 16], DataType.int16, {}, True),
        ([1, 17, 4096, 16], DataType.int16, {}, False),
        # h > 4096 is OK but w > 4096 is not
        ([1, 4097, 10, 16], DataType.uint8, {"keep_dims": True}, True),
        ([1, 10, 4097, 16], DataType.int16, {"keep_dims": True}, False),
    )
    for ifm_shape, dtype, op_attrs, expected in cases:
        mean_op = create_mean(ifm_shape, [1, 1, 1, 16], [1, 2], dtype, op_attrs)
        assert bool(support.is_operator_supported(mean_op)) == expected
687
688
def test_lstm_support():
    # Baseline: a valid LSTM configuration is supported
    op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
    assert support.is_operator_supported(op)

    # CIFG is not supported: dropping either the input-to-input weights
    # (input 1) or the recurrent-to-input weights (input 5) must be rejected
    for missing_idx in (1, 5):
        kept = op.inputs[missing_idx]
        op.inputs[missing_idx] = None
        assert not support.is_operator_supported(op)
        op.inputs[missing_idx] = kept

    # Peephole (inputs 9-11), Projection (inputs 16-17) and Normalisation
    # (inputs 20-23) are not supported: populating any single one of those
    # optional inputs must be rejected. Any tensor works as a placeholder.
    placeholder = op.inputs[1]
    for optional_idx in (9, 10, 11, 16, 17, 20, 21, 22, 23):
        op.inputs[optional_idx] = placeholder
        assert not support.is_operator_supported(op)
        op.inputs[optional_idx] = None

    # The restored configuration is valid again
    assert support.is_operator_supported(op)
Johan Alfven8e525ca2023-05-07 13:12:37 +0200733
734
def test_rsqrt_support():
    # Rsqrt is a unary elementwise op, so it is built with a single IFM
    # (no IFM2), in the same way as the unary CLZ cases elsewhere in this file.
    # int8 is supported; uint8 and int16 are not.
    cases = (
        (DataType.int8, True),
        (DataType.uint8, False),
        (DataType.int16, False),
    )
    for datatype, expected in cases:
        op = testutil.create_elemwise_op(Op.Rsqrt, "op", [1, 8, 8, 8], None, [1, 8, 8, 8], datatype=datatype)
        assert bool(support.is_operator_supported(op)) == expected, datatype
Johan Alfven85b77902023-06-15 09:24:01 +0200745
746
def test_constraint_slice_inputs_const():
    # Begin and Size tensor cannot be non-const tensors
    ifm = Tensor([3, 1, 256], DataType.int8, "in")
    ofm = Tensor([1, 1, 256], DataType.int8, "out")
    # Test not supported op: Begin (input 1) and Size (input 2) are non-const
    begin = Tensor([3], DataType.int32, "begin")
    size = Tensor([3], DataType.int32, "size")
    op = testutil.create_op(Op.Slice, [ifm, begin, size], ofm)
    assert not support.is_operator_supported(op)
    # Test supported op: replace Begin and Size with const tensors. Size is
    # chosen to match the OFM shape so the slice is self-consistent.
    begin = create_const_tensor("begin", [3], DataType.int32, [0, 0, 0])
    size = create_const_tensor("size", [3], DataType.int32, [1, 1, 256])
    op.set_input_tensor(begin, 1)
    op.set_input_tensor(size, 2)
    assert support.is_operator_supported(op)
Johan Alfvena8fda882023-10-28 16:04:46 +0200762
763
def test_constraint_transpose():
    # Each case: (IFM shape, perm, expected support). The OFM shape is derived
    # from the IFM shape and perm (ofm[i] = ifm[perm[i]]). Every case therefore
    # describes a shape-consistent transpose, and the expected result depends
    # only on the IFM shape and perm.
    cases = (
        # Test supported op IFM rank 2
        ([2, 4], [1, 0], True),
        # Test supported op IFM rank 3
        ([2, 4, 6], [1, 0, 2], True),
        ([1, 4, 6], [0, 2, 1], True),
        ([2, 1, 6], [2, 1, 0], True),
        # Test supported op IFM rank 4
        ([1, 2, 4, 6], [0, 2, 1, 3], True),
        ([1, 1, 4, 6], [0, 1, 3, 2], True),
        ([1, 2, 1, 6], [0, 3, 2, 1], True),
        # Test not supported op IFM rank 3
        ([2, 4, 6], [0, 2, 1], False),
        ([2, 4, 6], [2, 1, 0], False),
        # Test not supported op IFM rank 4
        ([1, 2, 4, 6], [0, 1, 3, 2], False),
        ([1, 2, 4, 6], [0, 3, 2, 1], False),
        ([1, 2, 4, 6], [1, 0, 2, 3], False),
        ([1, 2, 4, 6], [2, 1, 0, 3], False),
        ([1, 2, 4, 6], [3, 1, 2, 0], False),
        ([1, 2, 4, 6], [3, 2, 1, 0], False),
    )
    for ifm_shape, perm_values, expected in cases:
        ofm_shape = [ifm_shape[axis] for axis in perm_values]
        ifm = Tensor(ifm_shape, DataType.int8, "ifm")
        perm = create_const_tensor("perm", [len(perm_values)], DataType.int32, perm_values)
        ofm = Tensor(ofm_shape, DataType.int8, "ofm")
        op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
        assert bool(support.is_operator_supported(op)) == expected, (ifm_shape, perm_values)
844 assert not support.is_operator_supported(op)